@@ -37,15 +37,47 @@ const tpuResponse =
37
37
'The sky appears blue due to a phenomenon called **Rayleigh scattering**.' ;
38
38
39
39
describe ( 'Gemma2 predictions' , async ( ) => {
40
// Shared fixtures for the request-shape assertions in both tests below.
// NOTE(review): the enclosing describe callback is declared `async` but no
// suite-level await is visible; Mocha does not wait on suite callbacks, so the
// keyword appears to have no effect here — consider dropping it (confirm).
//
// Placeholder endpoint path; used as the `endpoint` of each expected request.
// Presumably mirrors the value built inside the samples under test — verify.
+ const gemma2Endpoint =
41
+ 'projects/your-project-id/locations/your-vertex-endpoint-region/endpoints/your-vertex-endpoint-id' ;
42
// Generation settings, each wrapped as { kind: 'numberValue', numberValue }.
// The GPU test nests this object under `parameters`; the TPU test spreads it
// directly into the instance fields.
+ const configValues = {
43
+ maxOutputTokens : { kind : 'numberValue' , numberValue : 1024 } ,
44
+ temperature : { kind : 'numberValue' , numberValue : 0.9 } ,
45
+ topP : { kind : 'numberValue' , numberValue : 1 } ,
46
+ topK : { kind : 'numberValue' , numberValue : 1 } ,
47
+ } ;
48
// Prompt text expected as the `stringValue` of the request's input field.
+ const prompt = 'Why is the sky blue?' ;
40
49
// Stand-in for the prediction client that is handed to gemma2PredictGpu and
// gemma2PredictTpu. It exposes a single `predict` Sinon stub whose default
// resolution is an empty array; each test re-configures it with `.resolves(...)`
// before calling the sample. The same stub instance is shared by every test.
const predictionServiceClientMock = {
41
50
predict : sinon . stub ( ) . resolves ( [ ] ) ,
42
51
} ;
43
52
44
53
// Per-test cleanup. The diff replaces sinon.restore() with sinon.reset():
// the `predict` stub is shared across tests and each test asserts
// `calledOnce`, so recorded calls from the previous test must be cleared.
// Per Sinon's documentation, reset() clears both history and configured
// behaviour of fakes in the default sandbox, which is why every test sets
// `.resolves(...)` again rather than relying on the stub's initial default.
afterEach ( ( ) => {
45
- sinon . restore ( ) ;
54
+ sinon . reset ( ) ;
46
55
} ) ;
47
56
48
57
// GPU case: calls gemma2PredictGpu with the mocked client and checks that the
// returned text includes 'Rayleigh scattering', that `predict` was invoked
// exactly once, and that it received the expected request object.
// NOTE(review): the title says "interference"; it likely means "inference".
// Left unchanged here because the title is a runtime string.
it ( 'should run interference with GPU' , async ( ) => {
58
// Expected request: one struct instance whose fields are `inputs` (the prompt
// as a string value) and `parameters` (a nested struct holding configValues).
+ const expectedGpuRequest = {
59
+ endpoint : gemma2Endpoint ,
60
+ instances : [
61
+ {
62
+ kind : 'structValue' ,
63
+ structValue : {
64
+ fields : {
65
+ inputs : {
66
+ kind : 'stringValue' ,
67
+ stringValue : prompt ,
68
+ } ,
69
+ parameters : {
70
+ kind : 'structValue' ,
71
+ structValue : {
72
+ fields : configValues ,
73
+ } ,
74
+ } ,
75
+ } ,
76
+ } ,
77
+ } ,
78
+ ] ,
79
+ } ;
80
+
49
81
// Configure the mocked response for this test (body elided by the hunk below).
predictionServiceClientMock . predict . resolves ( [
50
82
{
51
83
predictions : [
@@ -59,9 +91,30 @@ describe('Gemma2 predictions', async () => {
59
91
const output = await gemma2PredictGpu ( predictionServiceClientMock ) ;
60
92
61
93
expect ( output ) . include ( 'Rayleigh scattering' ) ;
94
// NOTE(review): per Sinon's docs, calledWith compares only the arguments
// provided, so extra trailing arguments would still pass; calledWithExactly
// would be stricter — confirm whether that is wanted.
+ expect ( predictionServiceClientMock . predict . calledOnce ) . to . be . true ;
95
+ expect ( predictionServiceClientMock . predict . calledWith ( expectedGpuRequest ) )
96
+ . to . be . true ;
62
97
} ) ;
63
98
64
99
// TPU case: calls gemma2PredictTpu with the mocked client and checks that the
// returned text includes 'Rayleigh scattering', that `predict` was invoked
// exactly once, and that it received the expected request object.
// NOTE(review): the title says "interference"; it likely means "inference".
// Left unchanged here because the title is a runtime string.
it ( 'should run interference with TPU' , async ( ) => {
100
// Expected request: one struct instance whose fields are the configValues
// entries spread in at the top level plus `prompt` (the prompt as a string
// value) — a flat layout, unlike the nested `parameters` struct of the GPU case.
+ const expectedTpuRequest = {
101
+ endpoint : gemma2Endpoint ,
102
+ instances : [
103
+ {
104
+ kind : 'structValue' ,
105
+ structValue : {
106
+ fields : {
107
+ ...configValues ,
108
+ prompt : {
109
+ kind : 'stringValue' ,
110
+ stringValue : prompt ,
111
+ } ,
112
+ } ,
113
+ } ,
114
+ } ,
115
+ ] ,
116
+ } ;
117
+
65
118
// Configure the mocked response for this test (body elided by the hunk below).
predictionServiceClientMock . predict . resolves ( [
66
119
{
67
120
predictions : [
@@ -75,5 +128,8 @@ describe('Gemma2 predictions', async () => {
75
128
const output = await gemma2PredictTpu ( predictionServiceClientMock ) ;
76
129
77
130
expect ( output ) . include ( 'Rayleigh scattering' ) ;
131
// NOTE(review): per Sinon's docs, calledWith compares only the arguments
// provided, so extra trailing arguments would still pass; calledWithExactly
// would be stricter — confirm whether that is wanted.
+ expect ( predictionServiceClientMock . predict . calledOnce ) . to . be . true ;
132
+ expect ( predictionServiceClientMock . predict . calledWith ( expectedTpuRequest ) )
133
+ . to . be . true ;
78
134
} ) ;
79
135
} ) ;
0 commit comments