22# Copyright (c) Microsoft Corporation.
33# Licensed under the MIT License.
44# ------------------------------------
5+ from __future__ import annotations
6+
57from typing import List , Optional
68
79import httpx
1012from kiota_http .middleware import AsyncKiotaTransport
1113from kiota_http .middleware .middleware import BaseMiddleware
1214
13- from ._constants import DEFAULT_CONNECTION_TIMEOUT , DEFAULT_REQUEST_TIMEOUT
14- from ._enums import APIVersion , NationalClouds
1515from .middleware import (
1616 GraphAuthorizationHandler ,
1717 GraphMiddlewarePipeline ,
@@ -26,54 +26,33 @@ class GraphClientFactory(KiotaClientFactory):
2626 pipeline of graph specific middleware.
2727 """
2828
29- def __init__ (
30- self ,
31- api_version : APIVersion ,
32- base_url : NationalClouds ,
33- timeout : httpx .Timeout ,
34- client : Optional [httpx .AsyncClient ],
35- ):
36- """Class constructor accepts a user provided client object and kwargs to configure the
37- request handling behaviour of the client
38-
39- Args:
40- api_version (APIVersion): The Microsoft Graph API version to be used, for example
41- `APIVersion.v1` (default). This value is used in setting
42- the base url for all requests for that session.
43- base_url (NationalClouds): a supported Microsoft Graph cloud endpoint.
44- timeout (httpx.Timeout):Default connection and read timeout values for all session
45- requests.Specify a tuple in the form of httpx.Timeout(
46- REQUEST_TIMEOUT, connect=CONNECTION_TIMEOUT),
47- client (Optional[httpx.AsyncClient]): A custom AsynClient instance from the
48- python httpx library
49- """
50- self .api_version = api_version
51- self .base_url = base_url
52- self .timeout = timeout
53- self .client = client
54-
29+ @staticmethod
5530 def create_with_default_middleware (
56- self , token_provider : AccessTokenProvider
31+ client : httpx .AsyncClient ,
32+ token_provider : Optional [AccessTokenProvider ] = None
5733 ) -> httpx .AsyncClient :
5834 """Constructs native HTTP AsyncClient(httpx.AsyncClient) instances configured with
5935 a custom transport loaded with a default pipeline of middleware.
6036 Returns:
6137 httpx.AsycClient: An instance of the AsyncClient object
6238 """
63- if not self .client :
64- self .client = httpx .AsyncClient (
65- base_url = self ._get_base_url (), timeout = self .timeout , http2 = True
66- )
67- current_transport = self .client ._transport
68- middleware = self ._get_default_middleware (token_provider , current_transport )
39+ current_transport = client ._transport
40+ middleware = GraphClientFactory ._get_common_middleware ()
41+ if token_provider :
42+ middleware .insert (0 , GraphAuthorizationHandler (token_provider ))
43+
44+ middleware_pipeline = GraphClientFactory ._create_middleware_pipeline (
45+ middleware , current_transport
46+ )
6947
70- self . client ._transport = AsyncKiotaTransport (
71- transport = current_transport , middleware = middleware
48+ client ._transport = AsyncKiotaTransport (
49+ transport = current_transport , middleware = middleware_pipeline
7250 )
73- return self . client
51+ return client
7452
53+ @staticmethod
7554 def create_with_custom_middleware (
76- self , middleware : Optional [List [BaseMiddleware ]]
55+ client : httpx . AsyncClient , middleware : Optional [List [BaseMiddleware ]]
7756 ) -> httpx .AsyncClient :
7857 """Applies a custom middleware chain to the HTTP Client
7958
@@ -82,36 +61,34 @@ def create_with_custom_middleware(
8261 a middleware pipeline. The middleware should be arranged in the order in which they will
8362 modify the request.
8463 """
85- if not self .client :
86- self .client = httpx .AsyncClient (
87- base_url = self ._get_base_url (), timeout = self .timeout , http2 = True
88- )
89- current_transport = self .client ._transport
64+ current_transport = client ._transport
65+ middleware_pipeline = GraphClientFactory ._create_middleware_pipeline (
66+ middleware , current_transport
67+ )
9068
91- self . client ._transport = AsyncKiotaTransport (
92- transport = current_transport , middleware = middleware
69+ client ._transport = AsyncKiotaTransport (
70+ transport = current_transport , middleware = middleware_pipeline
9371 )
94- return self .client
72+ return client
73+
74+ @staticmethod
75+ def _get_common_middleware () -> List [BaseMiddleware ]:
76+ """
77+ Helper method that returns a list of cross cutting middleware
78+ """
79+ middleware = [GraphRedirectHandler (), GraphRetryHandler (), GraphTelemetryHandler ()]
9580
96- def _get_base_url (self ):
97- """Helper method to set the base url"""
98- base_url = self .base_url + '/' + self .api_version
99- return base_url
81+ return middleware
10082
101- def _get_default_middleware (
102- self , token_provider : AccessTokenProvider , transport : httpx .AsyncBaseTransport
83+ @staticmethod
84+ def _create_middleware_pipeline (
85+ middleware : Optional [List [BaseMiddleware ]], transport : httpx .AsyncBaseTransport
10386 ) -> GraphMiddlewarePipeline :
10487 """
10588 Helper method that constructs a middleware_pipeline with the specified middleware
10689 """
10790 middleware_pipeline = GraphMiddlewarePipeline (transport )
108- middleware = [
109- GraphAuthorizationHandler (token_provider ),
110- GraphRedirectHandler (),
111- GraphRetryHandler (),
112- GraphTelemetryHandler ()
113- ]
114- for ware in middleware :
115- middleware_pipeline .add_middleware (ware )
116-
91+ if middleware :
92+ for ware in middleware :
93+ middleware_pipeline .add_middleware (ware )
11794 return middleware_pipeline
0 commit comments