|
24 | 24 | }.to_json |
25 | 25 | end |
26 | 26 |
|
| 27 | + let(:model_id) { "anthropic.claude-3-sonnet-20240229-v1:0" } |
| 28 | + |
27 | 29 | before do |
28 | 30 | response_object = double("response_object") |
29 | 31 | allow(response_object).to receive(:body).and_return(StringIO.new(response)) |
30 | 32 | allow(subject.client).to receive(:invoke_model) |
31 | 33 | .with(matching( |
32 | | - model_id: "anthropic.claude-3-sonnet-20240229-v1:0", |
| 34 | + model_id:, |
33 | 35 | body: { |
34 | 36 | messages: [{role: "user", content: "What is the capital of France?"}], |
35 | 37 | stop_sequences: ["stop"], |
|
46 | 48 | expect( |
47 | 49 | subject.chat( |
48 | 50 | messages: [{role: "user", content: "What is the capital of France?"}], |
49 | | - model: "anthropic.claude-3-sonnet-20240229-v1:0", |
| 51 | + model: model_id, |
50 | 52 | stop_sequences: ["stop"] |
51 | 53 | ).chat_completion |
52 | 54 | ).to eq("The capital of France is Paris.") |
53 | 55 | end |
54 | 56 |
|
| 57 | + context "without default model" do |
| 58 | + let(:model_id) { "anthropic.claude-v2" } |
| 59 | + |
| 60 | + it "returns a completion" do |
| 61 | + expect( |
| 62 | + subject.chat( |
| 63 | + messages: [{role: "user", content: "What is the capital of France?"}], |
| 64 | + stop_sequences: ["stop"] |
| 65 | + ).chat_completion |
| 66 | + ).to eq("The capital of France is Paris.") |
| 67 | + end |
| 68 | + end |
| 69 | + |
55 | 70 | context "with streaming" do |
56 | 71 | let(:chunks) do |
57 | 72 | [ |
|
201 | 216 | end |
202 | 217 |
|
203 | 218 | context "with ai21 provider" do |
204 | | - let(:subject) { described_class.new(completion_model: "ai21.j2-ultra-v1") } |
| 219 | + let(:subject) { described_class.new(default_options: {completion_model_name: "ai21.j2-ultra-v1"}) } |
205 | 220 |
|
206 | 221 | let(:response) do |
207 | 222 | StringIO.new("{\"completions\":[{\"data\":{\"text\":\"\\nWhat is the meaning of life? What is the meaning of life?\\nWhat is the meaning\"}}]}") |
|
308 | 323 | context "with custom default_options" do |
309 | 324 | let(:subject) { |
310 | 325 | described_class.new( |
311 | | - completion_model: "ai21.j2-ultra-v1", |
312 | | - default_options: {max_tokens_to_sample: 100, temperature: 0.7} |
| 326 | + default_options: { |
| 327 | + completion_model_name: "ai21.j2-ultra-v1", |
| 328 | + max_tokens_to_sample: 100, |
| 329 | + temperature: 0.7 |
| 330 | + } |
313 | 331 | ) |
314 | 332 | } |
315 | 333 | let(:response_object) { double("response_object") } |
|
363 | 381 | end |
364 | 382 |
|
365 | 383 | context "with cohere provider" do |
366 | | - let(:subject) { described_class.new(completion_model: "cohere.command-text-v14") } |
| 384 | + let(:subject) { described_class.new(default_options: {completion_model_name: "cohere.command-text-v14"}) } |
367 | 385 |
|
368 | 386 | let(:response) do |
369 | 387 | StringIO.new("{\"generations\":[{\"text\":\"\\nWhat is the meaning of life? What is the meaning of life?\\nWhat is the meaning\"}]}") |
|
424 | 442 | context "with custom default_options" do |
425 | 443 | let(:subject) { |
426 | 444 | described_class.new( |
427 | | - completion_model: "cohere.command-text-v14", |
428 | | - default_options: {max_tokens_to_sample: 100, temperature: 0.7} |
| 445 | + default_options: { |
| 446 | + completion_model_name: "cohere.command-text-v14", |
| 447 | + max_tokens_to_sample: 100, |
| 448 | + temperature: 0.7 |
| 449 | + } |
429 | 450 | ) |
430 | 451 | } |
431 | 452 | let(:response_object) { double("response_object") } |
|
456 | 477 | end |
457 | 478 |
|
458 | 479 | context "with unsupported provider" do |
459 | | - let(:subject) { described_class.new(completion_model: "unsupported.provider") } |
| 480 | + let(:subject) { described_class.new(default_options: {completion_model_name: "unsupported.provider"}) } |
460 | 481 |
|
461 | 482 | it "raises an exception" do |
462 | 483 | expect { subject.complete(prompt: "Hello World") }.to raise_error("Completion provider unsupported is not supported.") |
|
492 | 513 | end |
493 | 514 |
|
494 | 515 | context "with cohere provider" do |
495 | | - let(:subject) { described_class.new(embedding_model: "cohere.embed-multilingual-v3") } |
| 516 | + let(:subject) { described_class.new(default_options: {embeddings_model_name: "cohere.embed-multilingual-v3"}) } |
496 | 517 |
|
497 | 518 | let(:response) do |
498 | 519 | StringIO.new("{\"embeddings\":[[0.1,0.2,0.3,0.4,0.5]]}") |
|
522 | 543 | end |
523 | 544 |
|
524 | 545 | context "with unsupported provider" do |
525 | | - let(:subject) { described_class.new(embedding_model: "unsupported.provider") } |
| 546 | + let(:subject) { described_class.new(default_options: {embeddings_model_name: "unsupported.provider"}) } |
526 | 547 |
|
527 | 548 | it "raises an exception" do |
528 | 549 | expect { subject.embed(text: "Hello World") }.to raise_error("Completion provider unsupported is not supported.") |
|
0 commit comments