44# license information.
55# --------------------------------------------------------------------------
66"""MDATP OData Driver class."""
7+ from __future__ import annotations
8+
79from dataclasses import dataclass , field
8- from typing import Any , Dict , List , Optional , Union
10+ from typing import Any , ClassVar , Iterable
911
1012import pandas as pd
13+ from typing_extensions import Self
1114
1215from ..._version import VERSION
1316from ...auth .azure_auth_core import AzureCloudConfig
@@ -37,10 +40,10 @@ class M365DConfiguration:
3740 api_version : str
3841 api_endpoint : str
3942 api_uri : str
40- scopes : List [str ]
43+ scopes : list [str ]
4144 oauth_v2 : bool = field (init = False )
4245
43- def __post_init__ (self ) :
46+ def __post_init__ (self : Self ) -> None :
4447 """Determine if the selected API supports Entra ID OAuth v2.0.
4548
4649 This is important because the fields in the request body
@@ -56,12 +59,20 @@ def __post_init__(self):
5659class MDATPDriver (OData ):
5760 """KqlDriver class to retrieve date from MS Defender APIs."""
5861
59- CONFIG_NAME = "MicrosoftDefender"
60- _ALT_CONFIG_NAMES = ["MDATPApp" ]
62+ CONFIG_NAME : ClassVar [ str ] = "MicrosoftDefender"
63+ _ALT_CONFIG_NAMES : ClassVar [ Iterable [ str ]] = ["MDATPApp" ]
6164
6265 def __init__ (
63- self , connection_str : Optional [str ] = None , instance : str = "Default" , ** kwargs
64- ):
66+ self : MDATPDriver ,
67+ connection_str : str | None = None ,
68+ instance : str = "Default" ,
69+ * ,
70+ cloud : str | None = None ,
71+ auth_type : str = "interactive" ,
72+ debug : bool = False ,
73+ max_threads : int = 4 ,
74+ ** kwargs ,
75+ ) -> None :
6576 """
6677 Instantiate MSDefenderDriver and optionally connect.
6778
@@ -71,31 +82,39 @@ def __init__(
7182 Connection string
7283 instance : str, optional
7384 The instance name from config to use
85+ cloud: str
86+ Name of the Azure Cloud to connect to.
7487
7588 """
76- super ().__init__ (** kwargs )
89+ super ().__init__ (
90+ debug = debug ,
91+ max_threads = max_threads ,
92+ ** kwargs ,
93+ )
7794
78- cs_dict = _get_driver_settings (
95+ cs_dict : dict [ str , str ] = _get_driver_settings (
7996 self .CONFIG_NAME , self ._ALT_CONFIG_NAMES , instance
8097 )
8198
82- self .cloud = cs_dict .pop ("cloud" , "global" )
83- if " cloud" in kwargs and kwargs [ "cloud" ] :
84- self .cloud = kwargs [ " cloud" ]
99+ self .cloud : str = cs_dict .pop ("cloud" , "global" )
100+ if cloud :
101+ self .cloud = cloud
85102
86- m365d_params = _select_api (self .data_environment , self .cloud )
103+ m365d_params : M365DConfiguration = _select_api (
104+ self .data_environment , self .cloud
105+ )
87106 self ._m365d_params : M365DConfiguration = m365d_params
88107 self .oauth_url = m365d_params .login_uri
89108 self .api_root = m365d_params .resource_uri
90109 self .api_ver = m365d_params .api_version
91- self .api_suffix = m365d_params .api_endpoint
110+ self .api_suffix : str = m365d_params .api_endpoint
92111 self .scopes = m365d_params .scopes
93112
94113 self .add_query_filter (
95114 "data_environments" , ("MDE" , "M365D" , "MDATP" , "M365DGraph" , "GraphHunting" )
96115 )
97116
98- self .req_body : Dict [str , Any ] = {}
117+ self .req_body : dict [str , Any ] = {}
99118 if "username" in cs_dict :
100119 delegated_auth = True
101120
@@ -111,13 +130,16 @@ def __init__(
111130 self .connect (
112131 connection_str ,
113132 delegated_auth = delegated_auth ,
114- auth_type = kwargs . get ( " auth_type" , "interactive" ) ,
133+ auth_type = auth_type ,
115134 location = cs_dict .get ("location" , "token_cache.bin" ),
116135 )
117136
118137 def query (
119- self , query : str , query_source : Optional [QuerySource ] = None , ** kwargs
120- ) -> Union [pd .DataFrame , Any ]:
138+ self : Self ,
139+ query : str ,
140+ query_source : QuerySource | None = None ,
141+ ** kwargs ,
142+ ) -> pd .DataFrame | str | None :
121143 """
122144 Execute query string and return DataFrame of results.
123145
@@ -145,7 +167,7 @@ def query(
145167 return data
146168
147169 if self .data_environment == DataEnvironment .M365DGraph :
148- date_fields = [
170+ date_fields : list [ str ] = [
149171 field ["name" ]
150172 for field in response ["schema" ]
151173 if field ["type" ] == "DateTime"
@@ -158,10 +180,10 @@ def query(
158180 ]
159181 data = ensure_df_datetimes (data , columns = date_fields )
160182 return data
161- return response
183+ return str ( response )
162184
163185
164- def _select_api (data_environment , cloud ) -> M365DConfiguration :
186+ def _select_api (data_environment : DataEnvironment , cloud : str ) -> M365DConfiguration :
165187 # pylint: disable=line-too-long
166188 """Return API and login URIs for selected provider type.
167189
@@ -177,11 +199,13 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
177199 # pylint: enable=line-too-long
178200 if data_environment == DataEnvironment .M365DGraph :
179201 az_cloud_config = AzureCloudConfig (cloud = cloud )
180- login_uri = f"{ az_cloud_config .authority_uri } {{tenantId}}/oauth2/v2.0/token"
181- resource_uri = az_cloud_config .endpoints ["microsoftGraphResourceId" ]
202+ login_uri : str = (
203+ f"{ az_cloud_config .authority_uri } {{tenantId}}/oauth2/v2.0/token"
204+ )
205+ resource_uri : str = az_cloud_config .endpoints ["microsoftGraphResourceId" ]
182206 api_version = "v1.0"
183207 api_endpoint = "/security/runHuntingQuery"
184- scopes = [f"{ resource_uri } ThreatHunting.Read.All" ]
208+ scopes : list [ str ] = [f"{ resource_uri } ThreatHunting.Read.All" ]
185209
186210 elif data_environment == DataEnvironment .M365D :
187211 login_uri = f"{ get_m365d_login_endpoint (cloud )} {{tenantId}}/oauth2/token"
@@ -197,7 +221,7 @@ def _select_api(data_environment, cloud) -> M365DConfiguration:
197221 api_endpoint = "/advancedqueries/run"
198222 scopes = [f"{ resource_uri } AdvancedQuery.Read" ]
199223
200- api_uri = f"{ resource_uri } { api_version } { api_endpoint } "
224+ api_uri : str = f"{ resource_uri } { api_version } { api_endpoint } "
201225
202226 return M365DConfiguration (
203227 login_uri = login_uri ,
0 commit comments