@@ -30,9 +30,11 @@ import {
3030import {
3131 Availability ,
3232 LanguageModel ,
33+ LanguageModelExpected ,
3334 LanguageModelMessage ,
3435 LanguageModelMessageContent ,
35- LanguageModelMessageRole
36+ LanguageModelMessageRole ,
37+ LanguageModelMessageType
3638} from '../types/language-model' ;
3739
3840/**
@@ -48,13 +50,10 @@ export class ChromeAdapter {
/**
 * Creates the adapter.
 *
 * @param languageModelProvider - Optional on-device language model entry point; its
 * `availability` method is queried with the create options.
 * @param mode - Optional inference mode (compared against `'only_on_device'` elsewhere
 * in this class).
 * @param onDeviceParams - Optional on-device parameters. `createOptions` is always
 * initialized to an object so that later code can write request-time options into it.
 */
constructor(
  private languageModelProvider?: LanguageModel,
  private mode?: InferenceMode,
  private onDeviceParams: OnDeviceParams = {}
) {
  // Work on a private copy: this class writes into `createOptions` at request
  // time (e.g. `expectedInputs`), and those writes must not leak into the
  // object the caller passed in. Spreading `undefined` yields `{}`, so this
  // also guarantees `createOptions` is defined.
  this.onDeviceParams = {
    ...onDeviceParams,
    createOptions: { ...onDeviceParams.createOptions }
  };
}
5857
5958 /**
6059 * Checks if a given request can be made on-device.
@@ -85,8 +84,10 @@ export class ChromeAdapter {
8584 return false ;
8685 }
8786
87+ const expectedInputs = ChromeAdapter . extractExpectedInputs ( request ) ;
88+
8889 // Triggers out-of-band download so model will eventually become available.
89- const availability = await this . downloadIfAvailable ( ) ;
90+ const availability = await this . downloadIfAvailable ( expectedInputs ) ;
9091
9192 if ( this . mode === 'only_on_device' ) {
9293 return true ;
@@ -158,6 +159,33 @@ export class ChromeAdapter {
158159 ) ;
159160 }
160161
/**
 * Maps
 * <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
 * Vertex's input mime types</a> to
 * <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
 * Chrome's expected types</a>.
 *
 * <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
 * this method infers the types.</p>
 *
 * @param request - The request whose parts are scanned for `inlineData`.
 * @returns One entry per distinct top-level mime type (the text before the first `/`),
 * in first-seen order. Parts without `inlineData` contribute nothing.
 */
private static extractExpectedInputs(
  request: GenerateContentRequest
): LanguageModelExpected[] {
  // Dedupe on the primitive type string. A Set of `{ type }` objects compares
  // by identity, so freshly created objects would never be collapsed and a
  // request with two images would report the 'image' type twice.
  const types = new Set<LanguageModelMessageType>();
  for (const content of request.contents) {
    for (const part of content.parts) {
      if (part.inlineData) {
        // NOTE(review): the cast is unchecked; a mime type such as
        // 'application/pdf' would pass through as 'application' — confirm
        // whether such inputs can reach this method.
        const type = part.inlineData.mimeType.split(
          '/'
        )[0] as LanguageModelMessageType;
        types.add(type);
      }
    }
  }
  return Array.from(types, type => ({ type }));
}
188+
161189 /**
162190 * Asserts inference for the given request can be performed by an on-device model.
163191 */
@@ -196,12 +224,21 @@ export class ChromeAdapter {
196224 /**
197225 * Encapsulates logic to get availability and download a model if one is downloadable.
198226 */
199- private async downloadIfAvailable ( ) : Promise < Availability | undefined > {
227+ private async downloadIfAvailable (
228+ expectedInputs : LanguageModelExpected [ ]
229+ ) : Promise < Availability | undefined > {
230+ // Side-effect: updates construction-time params with request-time params.
231+ // This is required because params are referenced through multiple flows.
232+ // TODO: remove this side effect, since we need to also pass options when creating a session.
233+ Object . assign ( this . onDeviceParams . createOptions ! , { expectedInputs } ) ;
234+
200235 const availability = await this . languageModelProvider ?. availability (
201236 this . onDeviceParams . createOptions
202237 ) ;
203238
204239 if ( availability === Availability . downloadable ) {
240+ // Side-effect: triggers out-of-band model download.
241+ // This is required because Chrome manages the model download.
205242 this . download ( ) ;
206243 }
207244
0 commit comments