Skip to content

Commit a717ba0

Browse files
committed
fix: tableQuestionAnswering task
1 parent ccd9e84 commit a717ba0

File tree

5 files changed

+268
-71
lines changed

5 files changed

+268
-71
lines changed

extensions/firestore-huggingface-inference-api/functions/src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
HfInferenceEndpoint,
88
HfInference,
99
} from '@huggingface/inference';
10-
import { runHostedInference as runInference } from './run_inference';
10+
import { runInference } from './run_inference';
1111

1212
/**
1313
* Trigger inference on Firestore document creation.
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import { describe, expect, test } from '@jest/globals';
2+
import { runInference } from './run_inference';
3+
import { Task } from './tasks';
4+
import { HfInference, HfInferenceEndpoint } from '@huggingface/inference';
5+
import { DocumentSnapshot } from 'firebase-admin/firestore';
6+
7+
const inference = new HfInference('test');
8+
9+
jest.mock('@huggingface/inference', () => ({
10+
HfInference: jest.fn().mockImplementation(() => ({
11+
fillMask: jest.fn(),
12+
summarization: jest.fn(),
13+
questionAnswering: jest.fn(),
14+
tableQuestionAnswering: jest.fn(),
15+
})),
16+
HfInferenceEndpoint: jest.fn().mockImplementation(() => ({
17+
fillMask: jest.fn(),
18+
summarization: jest.fn(),
19+
questionAnswering: jest.fn(),
20+
tableQuestionAnswering: jest.fn(),
21+
})),
22+
}));
23+
24+
describe(Task.fillMask, () => {
25+
test('should throw an error if inputs are not provided', async () => {
26+
const snapshot = {
27+
data: () => ({
28+
inputs: undefined,
29+
}),
30+
} as any;
31+
32+
await expect(
33+
runInference(snapshot, Task.fillMask, inference),
34+
).rejects.toThrow(Error);
35+
});
36+
37+
test('should throw an error if inputs are not a string', async () => {
38+
const snapshot = {
39+
data: () => ({
40+
inputs: 123,
41+
}),
42+
} as any;
43+
44+
await expect(
45+
runInference(snapshot, Task.fillMask, inference),
46+
).rejects.toThrow(Error);
47+
});
48+
});
49+
50+
describe(Task.summarization, () => {
51+
test('should throw an error if inputs are not provided', async () => {
52+
const snapshot = {
53+
data: () => ({
54+
inputs: undefined,
55+
}),
56+
} as any;
57+
58+
await expect(
59+
runInference(snapshot, Task.summarization, inference),
60+
).rejects.toThrow(Error);
61+
});
62+
63+
test('should throw an error if inputs are not a string', async () => {
64+
const snapshot = {
65+
data: () => ({
66+
inputs: 123,
67+
}),
68+
} as any;
69+
70+
await expect(
71+
runInference(snapshot, Task.summarization, inference),
72+
).rejects.toThrow(Error);
73+
});
74+
});
75+
76+
// Testing questionAnswering task.
77+
describe(Task.questionAnswering, () => {
78+
let mockTask: Task;
79+
let inference: HfInference | HfInferenceEndpoint;
80+
81+
beforeAll(() => {
82+
mockTask = Task.questionAnswering;
83+
inference = new HfInference();
84+
});
85+
86+
it('should run without errors', async () => {
87+
const snapshot = {
88+
data: () => ({
89+
question: 'Who are you?',
90+
context: 'You are a cat.',
91+
}),
92+
} as any;
93+
94+
await expect(
95+
runInference(snapshot, mockTask, inference),
96+
).resolves.not.toThrow();
97+
});
98+
99+
it('should throw if `question` and `context` are not provided', async () => {
100+
const snapshot = {
101+
data: () => ({}),
102+
} as any;
103+
104+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
105+
});
106+
107+
it('should throw if `question` is not a string', async () => {
108+
const snapshot = {
109+
data: () => ({
110+
question: 123,
111+
}),
112+
} as any;
113+
114+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
115+
});
116+
117+
it('should throw if `context` is not a string', async () => {
118+
const snapshot = {
119+
data: () => ({
120+
context: 123,
121+
}),
122+
} as any;
123+
124+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
125+
});
126+
127+
it('should run with inference endpoint', async () => {
128+
const snapshot = {
129+
data: () => ({
130+
question: 'Who are you?',
131+
context: 'You are a cat.',
132+
}),
133+
} as any;
134+
135+
inference = new HfInferenceEndpoint('https://endpoint-url.com');
136+
137+
await expect(
138+
runInference(snapshot, mockTask, inference),
139+
).resolves.not.toThrow();
140+
});
141+
});
142+
143+
// Testing tableQuestionAnswering task.
144+
describe(Task.tableQuestionAnswering, () => {
145+
let mockTask: Task;
146+
let inference: HfInference | HfInferenceEndpoint;
147+
const snapshot = {
148+
data: () => ({
149+
inputs: {
150+
query: 'Who are you?',
151+
table: {
152+
name: ['John', 'Mary'],
153+
age: ['20', '30'],
154+
},
155+
},
156+
}),
157+
} as any;
158+
159+
beforeAll(() => {
160+
mockTask = Task.tableQuestionAnswering;
161+
inference = new HfInference();
162+
});
163+
164+
it('should run without errors', async () => {
165+
await expect(
166+
runInference(snapshot, mockTask, inference),
167+
).resolves.not.toThrow();
168+
});
169+
170+
it('should throw if `question` and `table` are not provided', async () => {
171+
const snapshot = {
172+
data: () => ({}),
173+
} as any;
174+
175+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
176+
});
177+
178+
it('should throw if `query` is not a string', async () => {
179+
const snapshot = {
180+
data: () => ({
181+
inputs: {
182+
query: 123,
183+
table: {
184+
name: ['John', 'Mary'],
185+
age: ['20', '30'],
186+
},
187+
},
188+
}),
189+
} as any;
190+
191+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
192+
});
193+
194+
it('should throw if `table` is not valid', async () => {
195+
const snapshot = {
196+
data: () => ({
197+
inputs: {
198+
query: '123',
199+
table: 123,
200+
},
201+
}),
202+
} as any;
203+
204+
await expect(runInference(snapshot, mockTask, inference)).rejects.toThrow();
205+
});
206+
207+
it('should run with inference endpoint', async () => {
208+
inference = new HfInferenceEndpoint('https://endpoint-url.com');
209+
210+
await expect(
211+
runInference(snapshot, mockTask, inference),
212+
).resolves.not.toThrow();
213+
});
214+
});

extensions/firestore-huggingface-inference-api/functions/src/run_inference.ts

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ import {
1616

1717
import config from './config';
1818
import { Task } from './tasks';
19+
import { FirestoreInput } from './types/table';
1920

2021
/**
2122
* Validate inputs and create a task.
2223
*
2324
* @param {functions.firestore.DocumentSnapshot} snapshot
2425
* @return {Task}
2526
*/
26-
export async function runHostedInference(
27+
export async function runInference(
2728
snapshot: functions.firestore.DocumentSnapshot,
2829
task: Task,
2930
inference: HfInference | HfInferenceEndpoint,
@@ -55,6 +56,7 @@ export async function runHostedInference(
5556

5657
return await inference.fillMask(options);
5758
}
59+
5860
case Task.summarization: {
5961
const { inputs } = snapshot.data();
6062
checkInputs(inputs);
@@ -96,24 +98,21 @@ export async function runHostedInference(
9698
}
9799

98100
case Task.tableQuestionAnswering: {
99-
const { query, table } = snapshot.data() as {
100-
query?: string;
101-
table?: Record<string, string[]>;
102-
};
103-
104-
if (!query || !table || typeof query !== 'string')
101+
const { inputs } = snapshot.data();
102+
if (
103+
!inputs.query ||
104+
typeof inputs.query !== 'string' ||
105+
!validateFirestoreInput({ inputs: inputs })
106+
)
105107
throw new Error(
106-
'Field `query` and `table` are required and must be a string and an array of strings respectively',
108+
'Field `query` and `table` are required and must be a string and a dictionary respectively',
107109
);
108110

109111
return await inference.tableQuestionAnswering({
110112
...(inference instanceof HfInference && {
111113
model: config.modelId,
112114
}),
113-
inputs: {
114-
query,
115-
table,
116-
},
115+
inputs: inputs,
117116
});
118117
}
119118

@@ -254,3 +253,33 @@ function checkListInputs(inputs: any) {
254253
);
255254
}
256255
}
256+
257+
function validateFirestoreInput(data: any): data is FirestoreInput {
258+
if (!data || typeof data !== 'object' || !data.inputs) {
259+
return false;
260+
}
261+
262+
const inputs = data.inputs;
263+
264+
if (typeof inputs.query !== 'string') {
265+
return false;
266+
}
267+
268+
const table = inputs.table;
269+
270+
if (!table || typeof table !== 'object') {
271+
return false;
272+
}
273+
274+
// Check all properties of table are string arrays
275+
for (const key in table) {
276+
if (
277+
!Array.isArray(table[key]) ||
278+
!table[key].every((item: any) => typeof item === 'string')
279+
) {
280+
return false;
281+
}
282+
}
283+
284+
return true;
285+
}

extensions/firestore-huggingface-inference-api/functions/src/tasks.test.ts

Lines changed: 0 additions & 58 deletions
This file was deleted.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
interface TableData {
2+
[k: string]: string[];
3+
}
4+
5+
interface InferenceInput {
6+
query: string;
7+
table: TableData;
8+
}
9+
10+
export interface FirestoreInput {
11+
inputs: InferenceInput;
12+
}

0 commit comments

Comments
 (0)