Skip to content

Commit 2f48846

Browse files
authored
Feature: Allow passing more parameters to cohere embed models (#36)
1 parent eedc39e commit 2f48846

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

src/Schemas/Cohere/CohereEmbeddingsHandler.php

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
namespace Prism\Bedrock\Schemas\Cohere;
44

55
use Illuminate\Http\Client\Response;
6+
use Illuminate\Support\Arr;
67
use Prism\Bedrock\Contracts\BedrockEmbeddingsHandler;
78
use Prism\Prism\Embeddings\Request;
89
use Prism\Prism\Embeddings\Response as EmbeddingsResponse;
@@ -39,9 +40,15 @@ public static function buildPayload(Request $request): array
3940
{
4041
return array_filter([
4142
'texts' => $request->inputs(),
42-
'input_type' => 'search_document', // TODO: Need to PR providerOptions onto embeddings request to allow override.
43-
'truncate' => null, // TODO: Need to PR providerOptions onto embeddings request to allow override. Default for now.
44-
'embedding_types' => null, // TODO: Need to PR providerOptions onto embeddings request to allow override. Default for now.
43+
'input_type' => 'search_document',
44+
'truncate' => null,
45+
'embedding_types' => null,
46+
...Arr::only($request->providerOptions(), [
47+
'input_type',
48+
'embedding_types',
49+
'truncate',
50+
'output_dimension',
51+
]),
4552
]);
4653
}
4754

tests/Schemas/Cohere/CohereEmbeddingsTest.php

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
namespace Tests\Schemas\Cohere;
66

7+
use Illuminate\Support\Facades\Http;
78
use Prism\Prism\Prism;
89
use Prism\Prism\ValueObjects\Embedding;
910
use Tests\Fixtures\FixtureResponse;
@@ -63,3 +64,38 @@
6364
expect($response->embeddings[1]->embedding)->toEqual($embeddings[1]->embedding);
6465
expect($response->usage->tokens)->toBe(1);
6566
});
67+
68+
it('can set request params', function (): void {
69+
FixtureResponse::fakeResponseSequence('invoke', 'cohere/generate-embeddings-from-input', [
70+
'X-Amzn-Bedrock-Input-Token-Count' => 4,
71+
]);
72+
73+
$response = Prism::embeddings()
74+
->using('bedrock', 'cohere.embed-english-v3')
75+
->withProviderOptions([
76+
'input_type' => 'search_query',
77+
'truncate' => 'RIGHT',
78+
'embedding_types' => ['sparse', 'dense'],
79+
'output_dimension' => 1536,
80+
'some_other_option' => 'should be filtered out',
81+
])
82+
->fromInput('Hello, world!')
83+
->asEmbeddings();
84+
85+
$embeddings = json_decode(file_get_contents('tests/Fixtures/cohere/generate-embeddings-from-input-1.json'), true);
86+
$embeddings = array_map(fn (array $item): Embedding => Embedding::fromArray($item), data_get($embeddings, 'embeddings'));
87+
88+
Http::assertSent(function ($request): bool {
89+
$body = $request->data();
90+
91+
return $body['input_type'] === 'search_query'
92+
&& $body['truncate'] === 'RIGHT'
93+
&& $body['embedding_types'] === ['sparse', 'dense']
94+
&& $body['output_dimension'] === 1536
95+
&& ! array_key_exists('some_other_option', $body);
96+
});
97+
98+
expect($response->embeddings)->toBeArray();
99+
expect($response->embeddings[0]->embedding)->toEqual($embeddings[0]->embedding);
100+
expect($response->usage->tokens)->toBe(4);
101+
});

0 commit comments

Comments
 (0)