Skip to content

Commit 14111ff

Browse files
Add typing for Defender (#828)
* Apply typing * Add typing to MDATP driver * FIx incomplete typing in OData driver * Make cloud an explicit parameter * Return an empty dataframe when no results are returned * Return a string to be consistent with parent's definition * Explode MDATPDriver init's kwargs --------- Co-authored-by: Ian Hellen <ianhelle@microsoft.com>
1 parent 973e48d commit 14111ff

File tree

2 files changed

+51
-27
lines changed

2 files changed

+51
-27
lines changed

msticpy/data/drivers/mdatp_driver.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
# license information.
55
# --------------------------------------------------------------------------
66
"""MDATP OData Driver class."""
7+
from __future__ import annotations
8+
79
from dataclasses import dataclass, field
8-
from typing import Any, Dict, List, Optional, Union
10+
from typing import Any, ClassVar, Iterable
911

1012
import pandas as pd
13+
from typing_extensions import Self
1114

1215
from ..._version import VERSION
1316
from ...auth.azure_auth_core import AzureCloudConfig
@@ -37,10 +40,10 @@ class M365DConfiguration:
3740
api_version: str
3841
api_endpoint: str
3942
api_uri: str
40-
scopes: List[str]
43+
scopes: list[str]
4144
oauth_v2: bool = field(init=False)
4245

43-
def __post_init__(self):
46+
def __post_init__(self: Self) -> None:
4447
"""Determine if the selected API supports Entra ID OAuth v2.0.
4548
4649
This is important because the fields in the request body
@@ -56,12 +59,20 @@ def __post_init__(self):
5659
class MDATPDriver(OData):
5760
"""KqlDriver class to retrieve date from MS Defender APIs."""
5861

59-
CONFIG_NAME = "MicrosoftDefender"
60-
_ALT_CONFIG_NAMES = ["MDATPApp"]
62+
CONFIG_NAME: ClassVar[str] = "MicrosoftDefender"
63+
_ALT_CONFIG_NAMES: ClassVar[Iterable[str]] = ["MDATPApp"]
6164

6265
def __init__(
63-
self, connection_str: Optional[str] = None, instance: str = "Default", **kwargs
64-
):
66+
self: MDATPDriver,
67+
connection_str: str | None = None,
68+
instance: str = "Default",
69+
*,
70+
cloud: str | None = None,
71+
auth_type: str = "interactive",
72+
debug: bool = False,
73+
max_threads: int = 4,
74+
**kwargs,
75+
) -> None:
6576
"""
6677
Instantiate MSDefenderDriver and optionally connect.
6778
@@ -71,31 +82,39 @@ def __init__(
7182
Connection string
7283
instance : str, optional
7384
The instance name from config to use
85+
cloud: str
86+
Name of the Azure Cloud to connect to.
7487
7588
"""
76-
super().__init__(**kwargs)
89+
super().__init__(
90+
debug=debug,
91+
max_threads=max_threads,
92+
**kwargs,
93+
)
7794

78-
cs_dict = _get_driver_settings(
95+
cs_dict: dict[str, str] = _get_driver_settings(
7996
self.CONFIG_NAME, self._ALT_CONFIG_NAMES, instance
8097
)
8198

82-
self.cloud = cs_dict.pop("cloud", "global")
83-
if "cloud" in kwargs and kwargs["cloud"]:
84-
self.cloud = kwargs["cloud"]
99+
self.cloud: str = cs_dict.pop("cloud", "global")
100+
if cloud:
101+
self.cloud = cloud
85102

86-
m365d_params = _select_api(self.data_environment, self.cloud)
103+
m365d_params: M365DConfiguration = _select_api(
104+
self.data_environment, self.cloud
105+
)
87106
self._m365d_params: M365DConfiguration = m365d_params
88107
self.oauth_url = m365d_params.login_uri
89108
self.api_root = m365d_params.resource_uri
90109
self.api_ver = m365d_params.api_version
91-
self.api_suffix = m365d_params.api_endpoint
110+
self.api_suffix: str = m365d_params.api_endpoint
92111
self.scopes = m365d_params.scopes
93112

94113
self.add_query_filter(
95114
"data_environments", ("MDE", "M365D", "MDATP", "M365DGraph", "GraphHunting")
96115
)
97116

98-
self.req_body: Dict[str, Any] = {}
117+
self.req_body: dict[str, Any] = {}
99118
if "username" in cs_dict:
100119
delegated_auth = True
101120

@@ -111,13 +130,16 @@ def __init__(
111130
self.connect(
112131
connection_str,
113132
delegated_auth=delegated_auth,
114-
auth_type=kwargs.get("auth_type", "interactive"),
133+
auth_type=auth_type,
115134
location=cs_dict.get("location", "token_cache.bin"),
116135
)
117136

118137
def query(
119-
self, query: str, query_source: Optional[QuerySource] = None, **kwargs
120-
) -> Union[pd.DataFrame, Any]:
138+
self: Self,
139+
query: str,
140+
query_source: QuerySource | None = None,
141+
**kwargs,
142+
) -> pd.DataFrame | str | None:
121143
"""
122144
Execute query string and return DataFrame of results.
123145
@@ -145,7 +167,7 @@ def query(
145167
return data
146168

147169
if self.data_environment == DataEnvironment.M365DGraph:
148-
date_fields = [
170+
date_fields: list[str] = [
149171
field["name"]
150172
for field in response["schema"]
151173
if field["type"] == "DateTime"
@@ -158,10 +180,10 @@ def query(
158180
]
159181
data = ensure_df_datetimes(data, columns=date_fields)
160182
return data
161-
return response
183+
return str(response)
162184

163185

164-
def _select_api(data_environment, cloud) -> M365DConfiguration:
186+
def _select_api(data_environment: DataEnvironment, cloud: str) -> M365DConfiguration:
165187
# pylint: disable=line-too-long
166188
"""Return API and login URIs for selected provider type.
167189
@@ -177,11 +199,13 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
177199
# pylint: enable=line-too-long
178200
if data_environment == DataEnvironment.M365DGraph:
179201
az_cloud_config = AzureCloudConfig(cloud=cloud)
180-
login_uri = f"{az_cloud_config.authority_uri}{{tenantId}}/oauth2/v2.0/token"
181-
resource_uri = az_cloud_config.endpoints["microsoftGraphResourceId"]
202+
login_uri: str = (
203+
f"{az_cloud_config.authority_uri}{{tenantId}}/oauth2/v2.0/token"
204+
)
205+
resource_uri: str = az_cloud_config.endpoints["microsoftGraphResourceId"]
182206
api_version = "v1.0"
183207
api_endpoint = "/security/runHuntingQuery"
184-
scopes = [f"{resource_uri}ThreatHunting.Read.All"]
208+
scopes: list[str] = [f"{resource_uri}ThreatHunting.Read.All"]
185209

186210
elif data_environment == DataEnvironment.M365D:
187211
login_uri = f"{get_m365d_login_endpoint(cloud)}{{tenantId}}/oauth2/token"
@@ -197,7 +221,7 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
197221
api_endpoint = "/advancedqueries/run"
198222
scopes = [f"{resource_uri}AdvancedQuery.Read"]
199223

200-
api_uri = f"{resource_uri}{api_version}{api_endpoint}"
224+
api_uri: str = f"{resource_uri}{api_version}{api_endpoint}"
201225

202226
return M365DConfiguration(
203227
login_uri=login_uri,

msticpy/data/drivers/odata_driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class OData(DriverBase):
4949
"""Parent class to retrieve date from an oauth based API."""
5050

5151
CONFIG_NAME: ClassVar[str] = ""
52-
_ALT_CONFIG_NAMES: Iterable[str] = []
52+
_ALT_CONFIG_NAMES: ClassVar[Iterable[str]] = []
5353

5454
def __init__(
5555
self: OData,
@@ -377,7 +377,7 @@ def query_with_results(
377377

378378
if not result:
379379
LOGGER.warning("Query did not return any results.")
380-
return None, json_response
380+
return pd.DataFrame(), json_response
381381
return pd.json_normalize(result), json_response
382382

383383
# pylint: enable=too-many-branches

0 commit comments

Comments
 (0)