Skip to content

Commit d5be357

Browse files
author
Joanna Grycz
committed
Add check for request structure
1 parent e5cc8c6 commit d5be357

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed
File renamed without changes.
File renamed without changes.

generative-ai/snippets/test/gemma2Prediction.test.js renamed to ai-platform/snippets/test/gemma2Prediction.test.js

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,47 @@ const tpuResponse =
3737
'The sky appears blue due to a phenomenon called **Rayleigh scattering**.';
3838

3939
describe('Gemma2 predictions', async () => {
40+
const gemma2Endpoint =
41+
'projects/your-project-id/locations/your-vertex-endpoint-region/endpoints/your-vertex-endpoint-id';
42+
const configValues = {
43+
maxOutputTokens: {kind: 'numberValue', numberValue: 1024},
44+
temperature: {kind: 'numberValue', numberValue: 0.9},
45+
topP: {kind: 'numberValue', numberValue: 1},
46+
topK: {kind: 'numberValue', numberValue: 1},
47+
};
48+
const prompt = 'Why is the sky blue?';
4049
const predictionServiceClientMock = {
4150
predict: sinon.stub().resolves([]),
4251
};
4352

4453
afterEach(() => {
45-
sinon.restore();
54+
sinon.reset();
4655
});
4756

4857
it('should run interference with GPU', async () => {
58+
const expectedGpuRequest = {
59+
endpoint: gemma2Endpoint,
60+
instances: [
61+
{
62+
kind: 'structValue',
63+
structValue: {
64+
fields: {
65+
inputs: {
66+
kind: 'stringValue',
67+
stringValue: prompt,
68+
},
69+
parameters: {
70+
kind: 'structValue',
71+
structValue: {
72+
fields: configValues,
73+
},
74+
},
75+
},
76+
},
77+
},
78+
],
79+
};
80+
4981
predictionServiceClientMock.predict.resolves([
5082
{
5183
predictions: [
@@ -59,9 +91,30 @@ describe('Gemma2 predictions', async () => {
5991
const output = await gemma2PredictGpu(predictionServiceClientMock);
6092

6193
expect(output).include('Rayleigh scattering');
94+
expect(predictionServiceClientMock.predict.calledOnce).to.be.true;
95+
expect(predictionServiceClientMock.predict.calledWith(expectedGpuRequest))
96+
.to.be.true;
6297
});
6398

6499
it('should run interference with TPU', async () => {
100+
const expectedTpuRequest = {
101+
endpoint: gemma2Endpoint,
102+
instances: [
103+
{
104+
kind: 'structValue',
105+
structValue: {
106+
fields: {
107+
...configValues,
108+
prompt: {
109+
kind: 'stringValue',
110+
stringValue: prompt,
111+
},
112+
},
113+
},
114+
},
115+
],
116+
};
117+
65118
predictionServiceClientMock.predict.resolves([
66119
{
67120
predictions: [
@@ -75,5 +128,8 @@ describe('Gemma2 predictions', async () => {
75128
const output = await gemma2PredictTpu(predictionServiceClientMock);
76129

77130
expect(output).include('Rayleigh scattering');
131+
expect(predictionServiceClientMock.predict.calledOnce).to.be.true;
132+
expect(predictionServiceClientMock.predict.calledWith(expectedTpuRequest))
133+
.to.be.true;
78134
});
79135
});

0 commit comments

Comments
 (0)