Skip to content

Commit da88e9e

Browse files
feat(api): manual updates
1 parent 292da2b commit da88e9e

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

src/gradient/_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,16 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
291291
if isinstance(custom_headers.get("Authorization"), Omit):
292292
return
293293

294+
if self.model_access_key and headers.get("Authorization"):
295+
return
296+
if isinstance(custom_headers.get("Authorization"), Omit):
297+
return
298+
299+
if self.agent_access_key and headers.get("Authorization"):
300+
return
301+
if isinstance(custom_headers.get("Authorization"), Omit):
302+
return
303+
294304
raise TypeError(
295305
'"Could not resolve authentication method. Expected access_token, agent_access_key, or model_access_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
296306
)
@@ -614,6 +624,16 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
614624
if isinstance(custom_headers.get("Authorization"), Omit):
615625
return
616626

627+
if self.model_access_key and headers.get("Authorization"):
628+
return
629+
if isinstance(custom_headers.get("Authorization"), Omit):
630+
return
631+
632+
if self.agent_access_key and headers.get("Authorization"):
633+
return
634+
if isinstance(custom_headers.get("Authorization"), Omit):
635+
return
636+
617637
raise TypeError(
618638
'"Could not resolve authentication method. Expected access_token, agent_access_key, or model_access_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
619639
)

tests/conftest.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
4646
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
4747

4848
access_token = "My Access Token"
49-
model_access_key = "My Model Access Key"
50-
agent_access_key = "My Agent Access Key"
5149

5250

5351
@pytest.fixture(scope="session")
@@ -56,13 +54,7 @@ def client(request: FixtureRequest) -> Iterator[Gradient]:
5654
if not isinstance(strict, bool):
5755
raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
5856

59-
with Gradient(
60-
base_url=base_url,
61-
access_token=access_token,
62-
model_access_key=model_access_key,
63-
agent_access_key=agent_access_key,
64-
_strict_response_validation=strict,
65-
) as client:
57+
with Gradient(base_url=base_url, access_token=access_token, _strict_response_validation=strict) as client:
6658
yield client
6759

6860

@@ -87,11 +79,6 @@ async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncGradient]:
8779
raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict")
8880

8981
async with AsyncGradient(
90-
base_url=base_url,
91-
access_token=access_token,
92-
model_access_key=model_access_key,
93-
agent_access_key=agent_access_key,
94-
_strict_response_validation=strict,
95-
http_client=http_client,
82+
base_url=base_url, access_token=access_token, _strict_response_validation=strict, http_client=http_client
9683
) as client:
9784
yield client

0 commit comments

Comments
 (0)