Skip to content

Commit 495282b

Browse files
authored
Implement ZeroShotImageClassificationWidget (#322)
* Implement `ZeroShotImageClassificationWidget` * Fix output parsing * Add model example
1 parent d63e4cd commit 495282b

File tree

3 files changed

+218
-0
lines changed

3 files changed

+218
-0
lines changed

js/src/lib/components/InferenceWidget/InferenceWidget.svelte

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import TabularDataWidget from "./widgets/TabularDataWidget/TabularDataWidget.svelte";
2626
import ReinforcementLearningWidget from "./widgets/ReinforcementLearningWidget/ReinforcementLearningWidget.svelte";
2727
import ZeroShotClassificationWidget from "./widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte";
28+
import ZeroShotImageClassificationWidget from "./widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte";
2829
2930
export let apiToken: WidgetProps["apiToken"] = undefined;
3031
export let callApiOnMount = false;
@@ -70,6 +71,7 @@
7071
"reinforcement-learning": ReinforcementLearningWidget,
7172
"zero-shot-classification": ZeroShotClassificationWidget,
7273
"document-question-answering": VisualQuestionAnsweringWidget,
74+
"zero-shot-image-classification": ZeroShotImageClassificationWidget,
7375
};
7476
7577
$: widgetComponent = WIDGET_COMPONENTS[model.pipeline_tag ?? ""];
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
<script lang="ts">
2+
import type { WidgetProps } from "../../shared/types";
3+
4+
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
5+
import WidgetDropzone from "../../shared/WidgetDropzone/WidgetDropzone.svelte";
6+
import WidgetTextInput from "../../shared/WidgetTextInput/WidgetTextInput.svelte";
7+
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
8+
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
9+
import WidgetOutputChart from "../../shared/WidgetOutputChart/WidgetOutputChart.svelte";
10+
import { addInferenceParameters, getResponse } from "../../shared/helpers";
11+
12+
// --- Component props (typed through WidgetProps) ---
// apiToken, apiUrl, model and includeCredentials are forwarded to getResponse();
// apiUrl, model and noTitle are also passed to <WidgetWrapper>.
export let apiToken: WidgetProps["apiToken"];
13+
export let apiUrl: WidgetProps["apiUrl"];
14+
export let model: WidgetProps["model"];
15+
export let noTitle: WidgetProps["noTitle"];
16+
export let includeCredentials: WidgetProps["includeCredentials"];
17+
18+
// --- Local widget state ---
// Comma-separated class names; bound to the text input and filled from input samples.
let candidateLabels = "";
19+
// Compute time reported by the last successful response ("" when there is none).
let computeTime = "";
20+
// Message shown for validation or API errors ("" when there is none).
let error: string = "";
21+
// True while an image is being read or an inference request is in flight.
let isLoading = false;
22+
// Set when the API answers with status "loading-model"; reset on every response.
let modelLoading = {
23+
isLoading: false,
24+
estimatedTime: 0,
25+
};
26+
// Parsed classification result rendered by the output chart.
let output: Array<{ label: string; score: number }> = [];
27+
// JSON form of the last response as returned by getResponse().
// NOTE(review): declared as string but left undefined until the first getOutput() call.
let outputJson: string;
28+
// Source of the preview image: an object URL for uploads, or the sample's `src`.
let imgSrc = "";
29+
// Base64 payload of the current image (data-URL prefix removed), sent as `image` in the request body.
let imageBase64 = "";
30+
31+
/**
 * Handle an image chosen through the dropzone or the file input.
 *
 * Points the preview at a fresh object URL for the file and then reads the
 * file into `imageBase64` so it can be sent with the next request.
 *
 * @param file - The image selected by the user.
 */
async function onSelectFile(file: File | Blob) {
	// Release the object URL of the previous upload; without this every newly
	// selected file keeps its blob alive for the lifetime of the page.
	// Sample images use a regular URL, so only `blob:` URLs are revoked.
	if (typeof imgSrc === "string" && imgSrc.startsWith("blob:")) {
		URL.revokeObjectURL(imgSrc);
	}
	imgSrc = URL.createObjectURL(file);
	await updateImageBase64(file);
}
35+
36+
/**
 * Read `file` as a data URL and store its base64 payload in `imageBase64`.
 *
 * `isLoading` is set to true while the read is in progress and is cleared
 * again on both success and failure.
 *
 * @param file - The image to encode.
 * @returns A promise that resolves once `imageBase64` has been updated and
 *   rejects if the file cannot be read.
 */
function updateImageBase64(file: File | Blob): Promise<void> {
	return new Promise((resolve, reject) => {
		const fileReader = new FileReader();

		fileReader.onload = () => {
			isLoading = false;
			const dataUrl = fileReader.result;
			if (typeof dataUrl !== "string") {
				reject(new TypeError("Unexpected FileReader result: expected a data URL string"));
				return;
			}
			// A data URL looks like "data:<mime>;base64,<payload>"; keep only the payload.
			// Fall back to "" so a malformed URL is treated as "no image" instead of `undefined`.
			imageBase64 = dataUrl.split(",")[1] ?? "";
			resolve();
		};

		fileReader.onerror = (e) => {
			// Clear the flag here too, otherwise a failed read leaves the widget stuck in its loading state.
			isLoading = false;
			reject(fileReader.error ?? e);
		};

		isLoading = true;
		fileReader.readAsDataURL(file);
	});
}
54+
55+
/**
 * Type guard for the API response body.
 *
 * @param arg - The value to check (typically the decoded response body).
 * @returns True when `arg` is an array whose every element is an object with a
 *   string `label` and a numeric `score`. An empty array is considered valid.
 */
function isValidOutput(arg: unknown): arg is { label: string; score: number }[] {
	return (
		Array.isArray(arg) &&
		arg.every(
			(x) =>
				// Reject null/undefined/primitive entries up front; reading `.label`
				// on them would throw instead of reporting an invalid output.
				typeof x === "object" &&
				x !== null &&
				typeof x.label === "string" &&
				typeof x.score === "number"
		)
	);
}
63+
64+
/**
 * Validate the response body and return it as a typed classification result.
 *
 * @param body - The decoded response body.
 * @returns The body, narrowed to an array of `{ label, score }` entries.
 * @throws {TypeError} When the body does not have that shape.
 */
function parseOutput(body: unknown): Array<{ label: string; score: number }> {
	if (isValidOutput(body)) {
		return body;
	}
	// The message names the shape that isValidOutput actually checks for.
	throw new TypeError(
		"Invalid output: output must be of type Array<{label: string; score: number}>"
	);
}
72+
73+
/**
 * Show an input sample in the widget without running inference:
 * copies the sample's labels into the text input and its image URL into the preview.
 *
 * @param sample - Sample entry; its `candidate_labels` and `src` fields are read.
 */
function previewInputSample(sample: Record<string, any>) {
	const { candidate_labels: sampleLabels, src: sampleSrc } = sample;
	candidateLabels = sampleLabels;
	imgSrc = sampleSrc;
}
77+
78+
/**
 * Load an input sample into the widget and run inference on it.
 *
 * Copies the sample's labels and image URL into the widget state, downloads
 * the image, encodes it to base64 and then calls `getOutput()`.
 * Download or read failures are shown through `error` instead of being left
 * as an unhandled promise rejection.
 *
 * @param sample - Sample entry; its `candidate_labels` and `src` fields are read.
 */
async function applyInputSample(sample: Record<string, any>) {
	candidateLabels = sample.candidate_labels;
	imgSrc = sample.src;
	// Drop the previously encoded image so that, if the download fails, a later
	// submit cannot send the old image while the preview shows the new one.
	imageBase64 = "";

	try {
		const res = await fetch(imgSrc);
		if (!res.ok) {
			// fetch() only rejects on network errors; HTTP error statuses must be checked explicitly.
			throw new Error(`Failed to fetch sample image (HTTP ${res.status})`);
		}
		const blob = await res.blob();
		await updateImageBase64(blob);
	} catch (err) {
		isLoading = false;
		error = err instanceof Error ? err.message : String(err);
		return;
	}

	await getOutput();
}
86+
87+
/**
 * Send the current image and candidate labels to the inference API and
 * update the widget state from the response.
 *
 * Validation failures (no label, no image) are reported through `error`
 * without issuing a request. When the API reports that the model is still
 * loading, the request is retried with `withModelLoading` set to true.
 *
 * @param withModelLoading - Forwarded to `getResponse`; set to true on the
 *   retry issued after a "loading-model" response.
 */
async function getOutput(withModelLoading = false) {
	// Normalize the comma-separated input: trim each label and drop empty
	// entries, so "cat, dog," becomes "cat,dog" and ",," counts as no label.
	const trimmedCandidateLabels = candidateLabels
		.split(",")
		.map((label) => label.trim())
		.filter((label) => label.length > 0)
		.join(",");

	if (!trimmedCandidateLabels) {
		error = "You need to input at least one label";
		output = [];
		outputJson = "";
		return;
	}

	if (!imageBase64) {
		error = "You need to upload an image";
		output = [];
		outputJson = "";
		return;
	}

	const requestBody = {
		image: imageBase64,
		parameters: {
			candidate_labels: trimmedCandidateLabels,
		},
	};
	// NOTE(review): presumably adds model-specific inference parameters to the body — confirm in shared/helpers.
	addInferenceParameters(requestBody, model);

	isLoading = true;

	const res = await getResponse(
		apiUrl,
		model.id,
		requestBody,
		apiToken,
		parseOutput,
		withModelLoading,
		includeCredentials
	);

	isLoading = false;
	// Reset values
	computeTime = "";
	error = "";
	modelLoading = { isLoading: false, estimatedTime: 0 };
	output = [];
	outputJson = "";

	if (res.status === "success") {
		computeTime = res.computeTime;
		output = res.output;
		outputJson = res.outputJson;
	} else if (res.status === "loading-model") {
		modelLoading = {
			isLoading: true,
			estimatedTime: res.estimatedTime,
		};
		// Await the retry so its outcome is tied to this call rather than left as a floating promise.
		await getOutput(true);
	} else if (res.status === "error") {
		error = res.error;
	}
}
146+
</script>
147+
148+
<!-- Widget layout: image input, candidate labels and submit button on top; result chart at the bottom. -->
<WidgetWrapper
	{apiUrl}
	{applyInputSample}
	{computeTime}
	{error}
	{isLoading}
	{model}
	{modelLoading}
	{noTitle}
	{outputJson}
	{previewInputSample}
>
	<svelte:fragment slot="top">
		<form class="space-y-2">
			<!-- Drag-and-drop input; hidden via the `no-hover:hidden` class -->
			<WidgetDropzone
				classNames="no-hover:hidden"
				{isLoading}
				{imgSrc}
				{onSelectFile}
				onError={(e) => (error = e)}
			>
				{#if imgSrc}
					<img
						src={imgSrc}
						class="pointer-events-none shadow mx-auto max-h-44"
						alt=""
					/>
				{/if}
			</WidgetDropzone>
			<!-- Better UX for mobile/tablet through CSS breakpoints -->
			{#if imgSrc}
				<div
					class="mb-2 flex justify-center bg-gray-50 dark:bg-gray-900 with-hover:hidden"
				>
					<img src={imgSrc} class="pointer-events-none max-h-44" alt="" />
				</div>
			{/if}
			<WidgetFileInput
				accept="image/*"
				classNames="mr-2 with-hover:hidden"
				{isLoading}
				label="Browse for image"
				{onSelectFile}
			/>
			<WidgetTextInput
				bind:value={candidateLabels}
				label="Possible class names (comma-separated)"
				placeholder="Possible class names..."
			/>
			<WidgetSubmitBtn
				{isLoading}
				onClick={() => {
					getOutput();
				}}
			/>
		</form>
	</svelte:fragment>
	<svelte:fragment slot="bottom">
		<!-- Only render the chart once there is at least one classification result -->
		{#if output.length}
			<WidgetOutputChart classNames="pt-4" {output} />
		{/if}
	</svelte:fragment>
</WidgetWrapper>

js/src/routes/index.svelte

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import type { ModelData } from "../lib/interfaces/Types";
55
66
const models: ModelData[] = [
7+
{
8+
id: "openai/clip-vit-base-patch16",
9+
pipeline_tag: "zero-shot-image-classification",
10+
},
711
{
812
id: "ydshieh/vit-gpt2-coco-en",
913
pipeline_tag: "image-to-text",

0 commit comments

Comments
 (0)