Skip to content

Commit f534174

Browse files
committed
Add support for node.js external data & optimize downloading
1 parent 33a60c5 commit f534174

File tree

2 files changed

+113
-70
lines changed

2 files changed

+113
-70
lines changed

src/models.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
250250

251251
// handle onnx external data files
252252
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
253-
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
253+
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
254254
let externalDataPromises = [];
255255
if (use_external_data_format) {
256256
let external_data_format;
@@ -272,8 +272,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
272272
const path = `${fileName}${suffix}.onnx_data${i === 0 ? '' : '_' + i}`;
273273
const fullPath = `${options.subfolder ?? ''}/${path}`;
274274
externalDataPromises.push(new Promise(async (resolve, reject) => {
275-
const data = /** @type {Uint8Array} */ (await getModelFile(pretrained_model_name_or_path, fullPath, true, options, false));
276-
resolve({ path, data })
275+
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
276+
resolve(data instanceof Uint8Array ? { path, data } : path);
277277
}));
278278
}
279279

src/utils/hub.js

Lines changed: 110 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import fs from 'fs';
99
import path from 'path';
1010

11-
import { env } from '../env.js';
11+
import { apis, env } from '../env.js';
1212
import { dispatchCallback } from './core.js';
1313

1414
/**
@@ -288,29 +288,61 @@ class FileCache {
288288
/**
289289
* Adds the given response to the cache.
290290
* @param {string} request
291-
* @param {Response|FileResponse} response
291+
* @param {Response} response
292+
* @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional.
293+
* The function to call with progress updates
292294
* @returns {Promise<void>}
293295
*/
294-
async put(request, response) {
295-
const buffer = Buffer.from(await response.arrayBuffer());
296-
297-
let outputPath = path.join(this.path, request);
296+
async put(request, response, progress_callback = undefined) {
297+
let filePath = path.join(this.path, request);
298298

299299
try {
300-
await fs.promises.mkdir(path.dirname(outputPath), { recursive: true });
301-
await fs.promises.writeFile(outputPath, buffer);
300+
const contentLength = response.headers.get('Content-Length');
301+
const total = parseInt(contentLength ?? '0');
302+
let loaded = 0;
303+
304+
await fs.promises.mkdir(path.dirname(filePath), { recursive: true });
305+
const fileStream = fs.createWriteStream(filePath);
306+
const reader = response.body.getReader();
307+
308+
while (true) {
309+
const { done, value } = await reader.read();
310+
if (done) {
311+
break;
312+
}
313+
314+
await new Promise((resolve, reject) => {
315+
fileStream.write(value, (err) => {
316+
if (err) {
317+
reject(err);
318+
return;
319+
}
320+
resolve();
321+
});
322+
});
323+
324+
loaded += value.length;
325+
const progress = total ? (loaded / total) * 100 : 0;
326+
327+
progress_callback?.({ progress, loaded, total });
328+
}
302329

303-
} catch (err) {
304-
console.warn('An error occurred while writing the file to cache:', err)
330+
fileStream.close();
331+
} catch (error) {
332+
// Clean up the file if an error occurred during download
333+
try {
334+
await fs.promises.unlink(filePath);
335+
} catch { }
336+
throw error;
305337
}
306-
}
307338

308-
// TODO add the rest?
309-
// addAll(requests: RequestInfo[]): Promise<void>;
310-
// delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<boolean>;
311-
// keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Request>>;
312-
// match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<Response | undefined>;
313-
// matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Response>>;
339+
// TODO add the rest?
340+
// addAll(requests: RequestInfo[]): Promise<void>;
341+
// delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<boolean>;
342+
// keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Request>>;
343+
// match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<Response | undefined>;
344+
// matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Response>>;
345+
}
314346
}
315347

316348
/**
@@ -512,41 +544,45 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
512544
file: filename
513545
})
514546

515-
/** @type {Uint8Array} */
516-
let buffer;
517-
518-
if (!options.progress_callback) {
519-
// If no progress callback is specified, we can use the `.arrayBuffer()`
520-
// method to read the response.
521-
buffer = new Uint8Array(await response.arrayBuffer());
522-
523-
} else if (
524-
cacheHit // The item is being read from the cache
525-
&&
526-
typeof navigator !== 'undefined' && /firefox/i.test(navigator.userAgent) // We are in Firefox
527-
) {
528-
// Due to bug in Firefox, we cannot display progress when loading from cache.
529-
// Fortunately, since this should be instantaneous, this should not impact users too much.
530-
buffer = new Uint8Array(await response.arrayBuffer());
531-
532-
// For completeness, we still fire the final progress callback
533-
dispatchCallback(options.progress_callback, {
534-
status: 'progress',
535-
name: path_or_repo_id,
536-
file: filename,
537-
progress: 100,
538-
loaded: buffer.length,
539-
total: buffer.length,
540-
})
541-
} else {
542-
buffer = await readResponse(response, data => {
547+
let result;
548+
if (!(apis.IS_NODE_ENV && return_path)) {
549+
/** @type {Uint8Array} */
550+
let buffer;
551+
552+
if (!options.progress_callback) {
553+
// If no progress callback is specified, we can use the `.arrayBuffer()`
554+
// method to read the response.
555+
buffer = new Uint8Array(await response.arrayBuffer());
556+
557+
} else if (
558+
cacheHit // The item is being read from the cache
559+
&&
560+
typeof navigator !== 'undefined' && /firefox/i.test(navigator.userAgent) // We are in Firefox
561+
) {
562+
// Due to a bug in Firefox, we cannot display progress when loading from cache.
563+
// Fortunately, since this should be instantaneous, this should not impact users too much.
564+
buffer = new Uint8Array(await response.arrayBuffer());
565+
566+
// For completeness, we still fire the final progress callback
543567
dispatchCallback(options.progress_callback, {
544568
status: 'progress',
545569
name: path_or_repo_id,
546570
file: filename,
547-
...data,
571+
progress: 100,
572+
loaded: buffer.length,
573+
total: buffer.length,
548574
})
549-
})
575+
} else {
576+
buffer = await readResponse(response, data => {
577+
dispatchCallback(options.progress_callback, {
578+
status: 'progress',
579+
name: path_or_repo_id,
580+
file: filename,
581+
...data,
582+
})
583+
})
584+
}
585+
result = buffer;
550586
}
551587

552588
if (
@@ -557,36 +593,43 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
557593
// Check again whether request is in cache. If not, we add the response to the cache
558594
(await cache.match(cacheKey) === undefined)
559595
) {
560-
// NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files
561-
await cache.put(cacheKey, new Response(buffer, {
562-
headers: response.headers
563-
}))
564-
.catch(err => {
565-
// Do not crash if unable to add to cache (e.g., QuotaExceededError).
566-
// Rather, log a warning and proceed with execution.
567-
console.warn(`Unable to add response to browser cache: ${err}.`);
568-
});
596+
if (!result) {
597+
// We haven't yet read the response body, so we need to do so now.
598+
await cache.put(cacheKey, /** @type {Response} */(response), options.progress_callback);
599+
} else {
600+
// NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files
601+
await cache.put(cacheKey, new Response(result, {
602+
headers: response.headers
603+
}))
604+
.catch(err => {
605+
// Do not crash if unable to add to cache (e.g., QuotaExceededError).
606+
// Rather, log a warning and proceed with execution.
607+
console.warn(`Unable to add response to browser cache: ${err}.`);
608+
});
609+
}
569610
}
570611
dispatchCallback(options.progress_callback, {
571612
status: 'done',
572613
name: path_or_repo_id,
573614
file: filename
574615
});
575616

576-
if (return_path) {
577-
if (response instanceof FileResponse) {
578-
return response.filePath;
579-
} else {
580-
const path = await cache.match(cacheKey);
581-
if (path instanceof FileResponse) {
582-
return path.filePath;
583-
} else {
584-
throw new Error("Unable to return path for response.");
585-
}
617+
if (result) {
618+
if (return_path) {
619+
throw new Error("Cannot return path in a browser environment.")
586620
}
621+
return result;
622+
}
623+
if (response instanceof FileResponse) {
624+
return response.filePath;
587625
}
588626

589-
return buffer;
627+
const path = await cache.match(cacheKey);
628+
if (path instanceof FileResponse) {
629+
return path.filePath;
630+
}
631+
throw new Error("Unable to return path for response.");
632+
590633
}
591634

592635
/**

0 commit comments

Comments
 (0)