@@ -30,12 +30,14 @@ import {
3030import {
3131 Availability ,
3232 LanguageModel ,
33+ LanguageModelCreateOptions ,
3334 LanguageModelExpected ,
3435 LanguageModelMessage ,
3536 LanguageModelMessageContent ,
3637 LanguageModelMessageRole ,
3738 LanguageModelMessageType
3839} from '../types/language-model' ;
40+ import { deepExtend } from '@firebase/util' ;
3941
4042/**
4143 * Defines an inference "backend" that uses Chrome's on-device model,
@@ -51,9 +53,7 @@ export class ChromeAdapter {
5153 private languageModelProvider ?: LanguageModel ,
5254 private mode ?: InferenceMode ,
5355 private onDeviceParams : OnDeviceParams = { }
54- ) {
55- this . onDeviceParams . createOptions ??= { } ;
56- }
56+ ) { }
5757
5858 /**
5959 * Checks if a given request can be made on-device.
@@ -84,10 +84,11 @@ export class ChromeAdapter {
8484 return false ;
8585 }
8686
87- const expectedInputs = ChromeAdapter . extractExpectedInputs ( request ) ;
87+ const requestOptions = this . inferCreateOptions ( request ) ;
88+ const mergedOptions = this . mergeCreateOptions ( requestOptions ) ;
8889
8990 // Triggers out-of-band download so model will eventually become available.
90- const availability = await this . downloadIfAvailable ( expectedInputs ) ;
91+ const availability = await this . downloadIfAvailable ( mergedOptions ) ;
9192
9293 if ( this . mode === 'only_on_device' ) {
9394 return true ;
@@ -119,7 +120,9 @@ export class ChromeAdapter {
119120 * @returns {@link Response }, so we can reuse common response formatting.
120121 */
121122 async generateContent ( request : GenerateContentRequest ) : Promise < Response > {
122- const session = await this . createSession ( ) ;
123+ const requestOptions = this . inferCreateOptions ( request ) ;
124+ const mergedOptions = this . mergeCreateOptions ( requestOptions ) ;
125+ const session = await this . createSession ( mergedOptions ) ;
123126 const contents = await Promise . all (
124127 request . contents . map ( ChromeAdapter . toLanguageModelMessage )
125128 ) ;
@@ -141,7 +144,9 @@ export class ChromeAdapter {
141144 async generateContentStream (
142145 request : GenerateContentRequest
143146 ) : Promise < Response > {
144- const session = await this . createSession ( ) ;
147+ const inferredOptions = this . inferCreateOptions ( request ) ;
148+ const mergedOptions = this . mergeCreateOptions ( inferredOptions ) ;
149+ const session = await this . createSession ( mergedOptions ) ;
145150 const contents = await Promise . all (
146151 request . contents . map ( ChromeAdapter . toLanguageModelMessage )
147152 ) ;
@@ -164,14 +169,14 @@ export class ChromeAdapter {
164169 * <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
165170 * Vertex's input mime types</a> to
166171 * <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
167- * Chrome's expected types</a>.
172+ * Chrome's expected input types</a>.
168173 *
169174 * <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
170175 * this method infers the types.</p>
171176 */
172- private static extractExpectedInputs (
177+ private inferCreateOptions (
173178 request : GenerateContentRequest
174- ) : LanguageModelExpected [ ] {
179+ ) : LanguageModelCreateOptions {
175180 const inputSet = new Set < LanguageModelExpected > ( ) ;
176181 for ( const content of request . contents ) {
177182 for ( const part of content . parts ) {
@@ -183,7 +188,23 @@ export class ChromeAdapter {
183188 }
184189 }
185190 }
186- return Array . from ( inputSet ) ;
191+
192+ return {
193+ expectedInputs : Array . from ( inputSet )
194+ } ;
195+ }
196+
197+ /**
198+ * Assembles a unified {@link LanguageModelCreateOptions} from create- and request-time options.
199+ * Request-time options take priority over create-time options.
200+ */
201+ private mergeCreateOptions (
202+ requestOptions : LanguageModelCreateOptions
203+ ) : LanguageModelCreateOptions {
204+ return deepExtend (
205+ this . onDeviceParams . createOptions ,
206+ requestOptions
207+ ) as LanguageModelCreateOptions ;
187208 }
188209
189210 /**
@@ -225,15 +246,10 @@ export class ChromeAdapter {
225246 * Encapsulates logic to get availability and download a model if one is downloadable.
226247 */
227248 private async downloadIfAvailable (
228- expectedInputs : LanguageModelExpected [ ]
249+ createOptions : LanguageModelCreateOptions
229250 ) : Promise < Availability | undefined > {
230- // Side-effect: updates construction-time params with request-time params.
231- // This is required because params are referenced through multiple flows.
232- // TODO: remove this side effect, since we need to also pass options when creating a session.
233- Object . assign ( this . onDeviceParams . createOptions ! , { expectedInputs } ) ;
234-
235251 const availability = await this . languageModelProvider ?. availability (
236- this . onDeviceParams . createOptions
252+ createOptions
237253 ) ;
238254
239255 if ( availability === Availability . downloadable ) {
@@ -328,16 +344,16 @@ export class ChromeAdapter {
328344 * <p>Chrome will remove a model from memory if it's no longer in use, so this method ensures a
329345 * new session is created before an old session is destroyed.</p>
330346 */
331- private async createSession ( ) : Promise < LanguageModel > {
347+ private async createSession (
348+ createOptions : LanguageModelCreateOptions
349+ ) : Promise < LanguageModel > {
332350 if ( ! this . languageModelProvider ) {
333351 throw new AIError (
334352 AIErrorCode . REQUEST_ERROR ,
335353 'Chrome AI requested for unsupported browser version.'
336354 ) ;
337355 }
338- const newSession = await this . languageModelProvider . create (
339- this . onDeviceParams . createOptions
340- ) ;
356+ const newSession = await this . languageModelProvider . create ( createOptions ) ;
341357 if ( this . oldSession ) {
342358 this . oldSession . destroy ( ) ;
343359 }
0 commit comments