Skip to content
Merged
74 changes: 49 additions & 25 deletions msticpy/data/drivers/mdatp_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# license information.
# --------------------------------------------------------------------------
"""MDATP OData Driver class."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from typing import Any, ClassVar, Iterable

import pandas as pd
from typing_extensions import Self

from ..._version import VERSION
from ...auth.azure_auth_core import AzureCloudConfig
Expand Down Expand Up @@ -37,10 +40,10 @@ class M365DConfiguration:
api_version: str
api_endpoint: str
api_uri: str
scopes: List[str]
scopes: list[str]
oauth_v2: bool = field(init=False)

def __post_init__(self):
def __post_init__(self: Self) -> None:
"""Determine if the selected API supports Entra ID OAuth v2.0.

This is important because the fields in the request body
Expand All @@ -56,12 +59,20 @@ def __post_init__(self):
class MDATPDriver(OData):
"""KqlDriver class to retrieve date from MS Defender APIs."""

CONFIG_NAME = "MicrosoftDefender"
_ALT_CONFIG_NAMES = ["MDATPApp"]
CONFIG_NAME: ClassVar[str] = "MicrosoftDefender"
_ALT_CONFIG_NAMES: ClassVar[Iterable[str]] = ["MDATPApp"]

def __init__(
self, connection_str: Optional[str] = None, instance: str = "Default", **kwargs
):
self: MDATPDriver,
connection_str: str | None = None,
instance: str = "Default",
*,
cloud: str | None = None,
auth_type: str = "interactive",
debug: bool = False,
max_threads: int = 4,
**kwargs,
) -> None:
"""
Instantiate MSDefenderDriver and optionally connect.

Expand All @@ -71,31 +82,39 @@ def __init__(
Connection string
instance : str, optional
The instance name from config to use
cloud: str
Name of the Azure Cloud to connect to.

"""
super().__init__(**kwargs)
super().__init__(
debug=debug,
max_threads=max_threads,
**kwargs,
)

cs_dict = _get_driver_settings(
cs_dict: dict[str, str] = _get_driver_settings(
self.CONFIG_NAME, self._ALT_CONFIG_NAMES, instance
)

self.cloud = cs_dict.pop("cloud", "global")
if "cloud" in kwargs and kwargs["cloud"]:
self.cloud = kwargs["cloud"]
self.cloud: str = cs_dict.pop("cloud", "global")
if cloud:
self.cloud = cloud

m365d_params = _select_api(self.data_environment, self.cloud)
m365d_params: M365DConfiguration = _select_api(
self.data_environment, self.cloud
)
self._m365d_params: M365DConfiguration = m365d_params
self.oauth_url = m365d_params.login_uri
self.api_root = m365d_params.resource_uri
self.api_ver = m365d_params.api_version
self.api_suffix = m365d_params.api_endpoint
self.api_suffix: str = m365d_params.api_endpoint
self.scopes = m365d_params.scopes

self.add_query_filter(
"data_environments", ("MDE", "M365D", "MDATP", "M365DGraph", "GraphHunting")
)

self.req_body: Dict[str, Any] = {}
self.req_body: dict[str, Any] = {}
if "username" in cs_dict:
delegated_auth = True

Expand All @@ -111,13 +130,16 @@ def __init__(
self.connect(
connection_str,
delegated_auth=delegated_auth,
auth_type=kwargs.get("auth_type", "interactive"),
auth_type=auth_type,
location=cs_dict.get("location", "token_cache.bin"),
)

def query(
self, query: str, query_source: Optional[QuerySource] = None, **kwargs
) -> Union[pd.DataFrame, Any]:
self: Self,
query: str,
query_source: QuerySource | None = None,
**kwargs,
) -> pd.DataFrame | str | None:
"""
Execute query string and return DataFrame of results.

Expand Down Expand Up @@ -145,7 +167,7 @@ def query(
return data

if self.data_environment == DataEnvironment.M365DGraph:
date_fields = [
date_fields: list[str] = [
field["name"]
for field in response["schema"]
if field["type"] == "DateTime"
Expand All @@ -158,10 +180,10 @@ def query(
]
data = ensure_df_datetimes(data, columns=date_fields)
return data
return response
return str(response)


def _select_api(data_environment, cloud) -> M365DConfiguration:
def _select_api(data_environment: DataEnvironment, cloud: str) -> M365DConfiguration:
# pylint: disable=line-too-long
"""Return API and login URIs for selected provider type.

Expand All @@ -177,11 +199,13 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
# pylint: enable=line-too-long
if data_environment == DataEnvironment.M365DGraph:
az_cloud_config = AzureCloudConfig(cloud=cloud)
login_uri = f"{az_cloud_config.authority_uri}{{tenantId}}/oauth2/v2.0/token"
resource_uri = az_cloud_config.endpoints["microsoftGraphResourceId"]
login_uri: str = (
f"{az_cloud_config.authority_uri}{{tenantId}}/oauth2/v2.0/token"
)
resource_uri: str = az_cloud_config.endpoints["microsoftGraphResourceId"]
api_version = "v1.0"
api_endpoint = "/security/runHuntingQuery"
scopes = [f"{resource_uri}ThreatHunting.Read.All"]
scopes: list[str] = [f"{resource_uri}ThreatHunting.Read.All"]

elif data_environment == DataEnvironment.M365D:
login_uri = f"{get_m365d_login_endpoint(cloud)}{{tenantId}}/oauth2/token"
Expand All @@ -197,7 +221,7 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
api_endpoint = "/advancedqueries/run"
scopes = [f"{resource_uri}AdvancedQuery.Read"]

api_uri = f"{resource_uri}{api_version}{api_endpoint}"
api_uri: str = f"{resource_uri}{api_version}{api_endpoint}"

return M365DConfiguration(
login_uri=login_uri,
Expand Down
4 changes: 2 additions & 2 deletions msticpy/data/drivers/odata_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OData(DriverBase):
"""Parent class to retrieve date from an oauth based API."""

CONFIG_NAME: ClassVar[str] = ""
_ALT_CONFIG_NAMES: Iterable[str] = []
_ALT_CONFIG_NAMES: ClassVar[Iterable[str]] = []

def __init__(
self: OData,
Expand Down Expand Up @@ -377,7 +377,7 @@ def query_with_results(

if not result:
LOGGER.warning("Query did not return any results.")
return None, json_response
return pd.DataFrame(), json_response
return pd.json_normalize(result), json_response

# pylint: enable=too-many-branches
Expand Down