🚀🚀🚀 Transformers.js V3 🚀🚀🚀 #545
Merged · 498 commits · +31,351 −11,172

Commits
0dba266
Early dereferencing for performance boosts
xenova 5e4e20f
cleanup
xenova dd6af93
Move quantization logic to `quantize.py`
xenova 04af3d5
update deps
xenova 9128651
Fix q4 quantization
xenova 83cbb21
save q4 quantization
xenova eb61344
Add decode ASR test
xenova cec2400
Do not process last chunk unnecessarily
xenova c835b54
fp16 disable_shape_infer if model is too large
xenova 45cd8d4
Use `check_and_save_model` for saving fp16 model
xenova 88f3e44
Reorder functions
xenova 23440f0
formatting
xenova b411e9f
Remove debug log
xenova 04a334a
Fix q8 quantization for models > 2GB
xenova cd1ea69
correct attribute
xenova a167f6e
Fix `TextGenerationPipeline`
xenova ea73289
Fix pauses in whisper word-level timestamps
xenova 344af32
Formatting
xenova c305c38
Sort added tokens by length to avoid early partial matches
xenova d6f6fd4
Add new tokenizer test
xenova 1557b8d
Only finish with newline if running in Node.js
xenova 9ac7ceb
Consider token timestamps when selecting longest common sequence
xenova 79ed46e
Create whisper word-level timestamps demo
xenova 8da6886
cleanup
xenova d709bd0
Fallback to WASM if WebGPU not supported
xenova 9ef3a6d
Reload model for each quantization mode
xenova 9787b75
Update conversion script requirements
xenova 974f086
Separate IO and Quantization args
xenova d042868
Use `const` where possible
xenova 1b4d242
Add `InterruptableStoppingCriteria`
xenova 31101c8
`@xenova/transformers` -> `@huggingface/transformers`
xenova e84322b
Override semver version
xenova bd94334
Add support for pyannote models
xenova 3dbc633
Update README.md
xenova 858e55d
Add listed support for pyannote
xenova 8bf0349
Add pyannote example code
xenova c52618c
Support specifying `min_num_frames`
xenova 96f19b0
Support simultaneous instantiation of multiple inference sessions
xenova 4ad43e2
Support broadcasting encoder outputs over decoder inputs
xenova c6aeb4b
Fix test
xenova 6d3ea4b
fix bundler config for latest ORT
fs-eire 38a3bf6
Only check fp16 support for webgpu device
xenova 9df84c4
Remove default chat templates
xenova fc3d860
Add support for gemma2
xenova 939920d
Add gemma2 generation test
xenova 5bb93a0
Update gemma2 config mapping
xenova 72ec168
Prioritize high-performance adapter when possible
xenova 9068a53
Set defaults for `tools` and `documents` in `apply_chat_template`
xenova 824538b
bump `@huggingface/jinja` -> 0.3.0
xenova 836c0af
Add `apply_chat_template` default parameters unit test
xenova 487d8b2
Merge branch 'v3' into @huggingface/transformers
xenova 1f6e0e1
Add prettier
xenova 55494d1
prettier format config files
xenova 5a68461
remove incorrect comment
xenova 437cb34
Merge branch 'pr/864' into @huggingface/transformers
xenova 5a6c926
Update onnxruntime-web version
xenova b19251b
Update webpack.config.js
xenova 820c1e2
Fix copy path
xenova b0dab91
Run `npm ci`
xenova 86b9b62
Fix bundling
xenova 222b94e
Do not set `preferredOutputLocation` if we are proxying
xenova b326cc9
Merge branch 'v3' into @huggingface/transformers
xenova ca67092
Update `@webgpu/types`
xenova 42076fd
Update SAM example
xenova 48d3142
Use `??=` operator where possible
xenova 3b1a4fd
Fix commonjs usage
xenova 9a73b5e
Mark `onnxruntime-node` and `sharp` as externals
xenova 9951aa5
Move `externals` into config
xenova c04d37e
Downgrade to onnxruntime 1.18.0
xenova d32fe2b
Finalize module/commonjs build
xenova 1530d50
Separate web and node builds
xenova b4df0e2
[version] Update to 3.0.0-alpha.1
xenova ab59c51
Default to CDN-hosted .wasm files
xenova 866b219
[version] Update to 3.0.0-alpha.2
xenova 4a3398d
bump versions
xenova 8891a14
[version] Update to 3.0.0-alpha.3
xenova a315933
Merge branch 'improve-conversion-script' into v3
xenova 12569b8
Consolidate conversion and quantization script
xenova 83f5718
Downgrade `onnxconverter-common`
xenova 6fa5fa6
Link to types in exports
xenova 2f1b210
Update list of supported tasks
xenova 27bc55d
Fixed unit tests
xenova 23d1150
Update imports
xenova f9070dc
Bump versions to `3.0.0-alpha.4`
xenova c3494e1
[version] Update to 3.0.0-alpha.4
xenova 973fb0d
Fix "Default condition should be last one"
xenova 7376ecf
Bump versions
xenova 0a04bc0
[version] Update to 3.0.0-alpha.5
xenova e4603cd
Update next.js client-side demo
xenova ff1853c
Initial WebNN Support
ibelem 15574bc
Mark fs, path and url as external packages for node build
xenova 7282862
Move content type map outside of `FileResponse` object
xenova 22f7ced
Add GPU support for Node.js
xenova 1e319a4
Bump versions
xenova d278891
[version] Update to 3.0.0-alpha.6
xenova 3fefa17
Fix conflicts
ibelem fa6cc70
bump dependency versions
xenova 7fa5326
Add support for device auto-detection
xenova 4ec77c1
Fix default device selection
xenova 5799e30
Merge branch 'pr/ibelem/890-1' into v3
xenova 5b2cac2
Improve WebNN selection
xenova ad23c50
Skip token callback if `skip_prompt` is set
xenova 5b84b62
Bump versions
xenova bcf6a86
[version] Update to 3.0.0-alpha.7
xenova b97ed0d
bump versions
xenova c5b7083
[version] Update to 3.0.0-alpha.8
xenova cbeefde
bump versions
xenova 59600f2
[version] Update to 3.0.0-alpha.9
xenova b2e025a
Add support for Sapiens
xenova 8661d95
Update default ONNX env
xenova 57db34d
Fix types
xenova 1b7f978
Topologically sort fp16 nodes
xenova 45d1526
Add marian unit test
xenova b903757
Re-order imports
xenova 633976f
Fix `NoBadWordsLogitsProcessor`
xenova 24d8787
Update package.json
xenova 9412ec4
[jest] Disable coverage
xenova 08e7388
Bump versions
xenova d5a8f87
[version] Update to 3.0.0-alpha.10
xenova 7843ad0
Improve node/web interoperability
xenova bf093ae
Fix scripts/requirements.txt
xenova 9a5ee42
Bump versions
xenova 535cdfe
[version] Update to 3.0.0-alpha.11
xenova 4e1acf0
Add support for JAIS models (#906)
xenova 488548d
Add JAIS to README
xenova 13aed41
Fix node/web interop (again)
xenova 7655f81
Bump versions
xenova 1c7e226
[version] Update to 3.0.0-alpha.12
xenova ab6b28b
Set `SapiensForNormalEstimation` to encoder-only
xenova 66c05d5
Implement `sub` tensor operation
xenova 31e8b2a
Bump versions
xenova bf3f7d5
[version] Update to 3.0.0-alpha.13
xenova c025356
Improve typing for `wrap` helper function
xenova 7ebdaf2
Update `preferredOutputLocation` type
xenova 3b8ddcb
Make `wrap` type more generic
xenova a385c6e
Re-use `segmentation_data`
xenova 537e958
Fix `min` type
xenova bcb28b3
Add support for Hiera models
xenova d21c87c
Fix reused loop variable (closes #910)
xenova 1d281f6
Add logits processor test file
xenova ba0427f
Fix test imports
xenova 3bc3e86
Bump versions
xenova 0518960
[version] Update to 3.0.0-alpha.14
xenova 552cdea
Add another `bad_words` logits processor test (closes #913)
xenova 3422a8b
Add support for GroupViT
xenova 3599902
Add zero-shot-image-classification unit test
xenova 5892ee8
Add maskformer model definitions
xenova c4dac77
Support universal image segmentation in `image-segmentation` pipeline
xenova f0c47be
Add support for PVT models
xenova d80d3a4
Add `post_process_instance_segmentation` function template
xenova 844099d
Add `library_name` option to convert.py
xenova ba5d725
Wrap onnxslim with try block
xenova b3691c8
Use const where possible
xenova dcf117f
Use const where possible (again)
xenova 9af026c
Create `MaskFormerFeatureExtractor`
xenova 0f8200c
Add support for MaskFormer
xenova e278c5e
Improve tool-use chat template detection
xenova 83fa58f
Add object detection pipeline unit test
xenova 86d6da4
Add support for ViTMSN and VitMAE
xenova 93b25fb
Bump ORT versions
xenova 2f680ee
Create `get_chat_template` helper function
xenova 2f9b2ed
Fix CI
xenova deec350
Run prettier on `tests/**`
xenova 48fa226
move certain tests to utils subfolder
xenova a10828f
Bump onnxruntime-web version
xenova ba58ea2
Bump `onnxruntime==1.19.2` in scripts/requirements.txt
xenova 4f17e95
Merge branch 'main' into v3
xenova c40a151
Merge branch 'main' into v3
xenova 30315b2
Sort `this.added_tokens` before creating regex (`.toSorted` is not av…
xenova d7df575
Rather make a copy of `this.added_tokens`
xenova a519379
Fix `.tokenize` with `fuse_unk=true`
xenova 89ddccf
Add blenderbot tokenizer tests
xenova 36ad144
Add t5 tokenizer tests
xenova 4765dd6
Add falcon tokenizer tests
xenova fd8b9a2
Run prettier
xenova 710816e
Add ESM tokenizer tests
xenova 0d3cd30
Run unit tests in parallel
xenova cc258c2
Fix `fuse_unk` for tokenizers with `byte_fallback=true` but no byte f…
xenova 4798755
Add llama tokenizer unit tests
xenova c6c3ae1
Update emoji test string names
xenova 79a7409
Move whisper-specific unit tests to subfolder
xenova 1a38804
Code formatting
xenova dabe6ae
Bump versions
xenova 54f1f21
[version] Update to 3.0.0-alpha.15
xenova a912d79
Add emoji tokenizer test cases for LlamaTokenizer
xenova 969d10e
Attempt to fix encoder-decoder memory leak
xenova 072cbbc
Remove unused code
xenova 14b4bd4
Fix BertNormalizer (strip `Mn` unicode characters)
xenova 6797771
Handle ZERO WIDTH JOINER (U+200D) characters
xenova f148afd
Add more spm normalization characters
xenova ca4b5b9
Add emoji unit tests for bert/t5
xenova 113c81e
[WebNN] Add support for specifying `free_dimension_overrides` in config
xenova 9005acc
Log warning if webnn is selected but `free_dimension_overrides` is not…
xenova 682c7d0
Fix unigram for multi-byte tokens
xenova 4a31e54
Add gemma tokenizer tests
xenova 7a16065
Allow user to specify device and dtype in config.json
xenova 4c1d21b
Update dependency versions
xenova 3c6a95a
Bump versions
xenova ac391d2
[version] Update to 3.0.0-alpha.16
xenova d30d3b7
Add CLIP tokenizer unit tests
xenova e089ef4
Add more tokenizer tests
xenova 2c9e271
Bump onnxruntime-web version
xenova ee1e32a
Bump versions
xenova f41e995
[version] Update to 3.0.0-alpha.17
xenova 9a42cf3
Add support for new `tokenizers>=0.2.0` BPE serialization format
xenova f534b35
Bump onnxruntime-web version
xenova 0c8b1af
Bump versions
xenova 2ca4178
[version] Update to 3.0.0-alpha.18
xenova a82e7ef
Keep encoder outputs on GPU
xenova c37a38c
Update whisper-webgpu demo dependencies
xenova e1c4fc6
Bump versions
xenova fe51609
[version] Update to 3.0.0-alpha.19
xenova b518866
Support to load ONNX APIs based on JS runtime (#947)
kallebysantos 95c8cc5
Allow specification of `use_external_data_format` in custom config
xenova 03eb77b
Update deberta unit tests
xenova c61a76b
Update roberta tokenizer tests
xenova 32d8df4
Support inferring unigram tokenizer type
xenova 6505abb
Reuse tokenizer tests for original t5-small
xenova 9619218
Remove redundant null coalesce
xenova 52c4ce7
Enable unit test coverage reports
xenova 12edaf0
Use `PROBLEMATIC_REGEX_MAP` for bloom tokenizer
xenova 5e7e82b
Improve tokenizer unit tests
xenova 795a61a
Update tokenizer unit tests
xenova 77ebe0d
Remove unused code
xenova 56eda3b
Add m2m_100 tokenizer unit tests
xenova 2040ad5
Add m2m translation pipeline unit test
xenova 8718c17
Add support for Depth Pro models
xenova a32efa3
Add whisper turbo alignment heads
xenova 8b0d330
Remove in-library list of supported models
xenova cf3f5c3
Bump versions
xenova 86fe175
[version] Update to 3.0.0-alpha.20
xenova 1c78278
Add function to map tensor data array.
BritishWerewolf a5e0210
Merge branch 'main' into v3
xenova 9f8fac0
Optimise loop to reduce calls to `this`
BritishWerewolf 1c43e3f
Merge branch 'pr/966' into v3
xenova 7a0f77c
Add back tensor map test
xenova da03a0a
Add support for granite models
xenova 37effa3
Allow multiple optional configs to be passed (+ reduce code duplication)
xenova f21b36e
Bump dependencies
xenova d26a663
Bump versions
xenova c337c3b
[version] Update to 3.0.0-alpha.21
xenova 92d0dc6
Add support for per-dtype `kv_cache_dtype`
xenova ea03bf5
Add text streamer unit test
xenova 27a033f
Bump ORT web version
xenova 19277ea
Bump versions
xenova 90a7490
[version] Update to 3.0.0-alpha.22
xenova 38773ea
Update repo name to `@huggingface/transformers.js`
xenova 832b5b7
Update tested node versions
xenova b871c08
Bump versions
xenova 7a58d6e
[version] Update to 3.0.0
xenova
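Taken together, the commits above cover the headline v3 changes: WebGPU and Node.js GPU inference, user-selectable device and dtype, and the rename from `@xenova/transformers` to `@huggingface/transformers`. The sketch below shows how those options plausibly combine in the released API; it is illustrative, with option values inferred from the commit messages rather than quoted from the PR:

```js
import { pipeline } from '@huggingface/transformers';

// Feature-extraction pipeline running on WebGPU with fp16 weights.
// 'device' and 'dtype' are the options referenced in the commit log
// ("Add GPU support for Node.js", "Allow user to specify device and
// dtype in config.json"); the exact values used here are assumptions.
const extractor = await pipeline(
  'feature-extraction',
  'Xenova/all-MiniLM-L6-v2',
  { device: 'webgpu', dtype: 'fp16' },
);

const embeddings = await extractor('Transformers.js V3!', {
  pooling: 'mean',
  normalize: true,
});
console.log(embeddings.dims); // [1, 384] for this model
```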
Files changed

````diff
@@ -101,7 +101,7 @@ npm i @xenova/transformers
 Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with:
 ```html
 <script type="module">
-  import { pipeline } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.16.0';
+  import { pipeline } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@3.0.0-alpha.0';
 </script>
 ```

@@ -134,8 +134,7 @@ Check out the Transformers.js [template](https://huggingface.co/new-space?templa
-By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/[email protected]/dist/), which should work out-of-the-box. You can customize this as follows:
+By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/[email protected]/dist/), which should work out-of-the-box. You can customize this as follows:

 ### Settings
````

(The removed and added lines in the second hunk differ only in the pinned CDN version, which the page's email-protection filter scrubbed to "[email protected]".)
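Later commits in this PR rename the package and finalize version 3.0.0, so the CDN import shown in the hunk above presumably changes again at release time. A hypothetical post-release equivalent:

```js
// Assumes the final 3.0.0 build is published under the renamed
// @huggingface scope (see the rename commit above); hypothetical URL.
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0';
```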
```diff
@@ -1,7 +1,6 @@
-By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/[email protected]/dist/), which should work out-of-the-box. You can customize this as follows:
+By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/[email protected]/dist/), which should work out-of-the-box. You can customize this as follows:

 ### Settings
```
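The `### Settings` heading in this snippet refers to the library's `env` configuration object, which is what lets users swap out the hosted models and CDN-hosted WASM binaries mentioned in the changed line. A minimal sketch, assuming the existing `env` fields (the paths are placeholders):

```js
import { env } from '@xenova/transformers';

// Load models from a self-hosted directory instead of the Hugging Face Hub.
env.allowRemoteModels = false;
env.localModelPath = '/path/to/models/'; // placeholder

// Serve the precompiled ONNX Runtime WASM binaries yourself
// instead of fetching them from the jsDelivr CDN.
env.backends.onnx.wasm.wasmPaths = '/path/to/wasm/'; // placeholder
```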
@@ -0,0 +1,24 @@ (new file)

```
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
```
@@ -0,0 +1,46 @@ (new file)

```html
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Transformers.js | WebGPU Benchmark</title>
</head>

<body>
  <h1>
    <a href="http://github.com/xenova/transformers.js" target="_blank">🤗 Transformers.js</a> WebGPU Benchmark
  </h1>
  <p>
    This benchmark measures the execution time of
    <a href="https://huggingface.co/Xenova/all-MiniLM-L6-v2" target="_blank">Xenova/all-MiniLM-L6-v2</a>
    (BERT-based embedding model)
    using the WASM and WebGPU execution providers across different batch sizes.
  </p>
  <div id="chart-container">
    <canvas id="chart"></canvas>
  </div>
  <div>
    <button id="start" disabled>Start Benchmark</button>
    <button id="stop" disabled>Stop Benchmark</button>
  </div>
  <label id="status"></label>
  <details open>
    <summary>Options</summary>
    <div>
      <label>Batch sizes</label>
      <input id="batch-sizes" value="1, 2, 4, 8, 16, 32" />
    </div>
    <div>
      <label>Sequence length</label>
      <input id="sequence-length" type="number" min="1" max="512" value="512" />
    </div>
    <div>
      <input id="x-scale" type="checkbox" />
      <label>Log scale (x)</label>
      <input id="y-scale" type="checkbox" />
      <label>Log scale (y)</label>
    </div>
  </details>
  <script type="module" src="/main.js"></script>
</body>

</html>
```
@@ -0,0 +1,255 @@ (new file)

```js
import './style.css';
import { env, AutoModel, ones } from '@xenova/transformers';
import Chart from 'chart.js/auto';

// Throw an error if WebGPU is not supported
if (!navigator.gpu) {
  const err = 'WebGPU is not supported by this browser.';
  alert(err);
  throw Error(err);
}

// Configure the WASM backend: load the binaries from the CDN and run single-threaded
env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/[email protected]/dist/';
env.backends.onnx.wasm.numThreads = 1;

// Reference the elements that we will need
const ctx = document.getElementById('chart');
const batchSizes = document.getElementById('batch-sizes');
const xscale = document.getElementById('x-scale');
const yscale = document.getElementById('y-scale');
const sequenceLength = document.getElementById('sequence-length');
const status = document.getElementById('status');
const start = document.getElementById('start');
const stop = document.getElementById('stop');

// Benchmark settings
const NUM_WARMUP_STEPS = 3;
const QUANTIZED = false;
const MODEL_ID = 'Xenova/all-MiniLM-L6-v2';

// Chart configuration
const config = {
  type: 'line',
  data: {
    labels: [],
    datasets: [{
      label: 'WASM',
      data: [],
      borderColor: 'red',
      backgroundColor: 'rgba(255, 0, 0, 0.5)',
    }, {
      label: 'WebGPU',
      data: [],
      borderColor: 'blue',
      backgroundColor: 'rgba(0, 0, 255, 0.5)',
    }]
  },
  options: {
    responsive: true,
    maintainAspectRatio: false,
    plugins: {
      legend: {
        position: 'top',
      },
    },
    scales: {
      x: {
        title: {
          display: true,
          text: 'Batch size',
        },
        min: 1,
      },
      y: {
        title: {
          display: true,
          text: 'Time (ms)',
        },
      }
    }
  },
};

const toggleScale = (chart, axis, enabled) => {
  chart.options.scales[axis].type = enabled ? 'logarithmic' : 'linear';
  chart.update();
};

xscale.addEventListener('change', () => toggleScale(chart, 'x', xscale.checked));
yscale.addEventListener('change', () => toggleScale(chart, 'y', yscale.checked));

const chart = new Chart(ctx, config);

status.textContent = 'Loading model...';

// CPU baseline: run on the WASM execution provider
let model_CPU;
try {
  model_CPU = await AutoModel.from_pretrained(MODEL_ID, {
    quantized: QUANTIZED,
    session_options: {
      executionProviders: ['wasm']
    }
  });
} catch (err) {
  status.textContent = err.message;
  alert(err.message);
  throw err;
}

// GPU model: same weights, WebGPU execution provider
let model_GPU;
try {
  model_GPU = await AutoModel.from_pretrained(MODEL_ID, {
    quantized: QUANTIZED,
    session_options: {
      executionProviders: ['webgpu']
    }
  });
} catch (err) {
  status.textContent = err.message;
  alert(err.message);
  throw err;
}

let adapterInfo;
try {
  // Shouldn't fail since the WebGPU model has loaded successfully
  const adapter = await navigator.gpu.requestAdapter();
  adapterInfo = await adapter.requestAdapterInfo();
} catch (err) {
  adapterInfo = {};
}

status.textContent = 'Ready';

let interrupted = false;
start.addEventListener('click', async () => {
  start.disabled = true;
  stop.disabled = false;
  interrupted = false;

  // Reset
  chart.data.labels = [];
  for (let i = 0; i < chart.data.datasets.length; ++i) {
    chart.data.datasets[i].data = [];
  }
  chart.update();

  const seqLength = parseInt(sequenceLength.value);

  status.textContent = 'Warming up...';

  const generateDummyInputs = (batch_size) => {
    const inputs = ones([batch_size, seqLength]);

    const model_inputs = {
      input_ids: inputs,
      attention_mask: inputs,
    };
    return model_inputs;
  };

  // Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load
  for (let i = 0; i < NUM_WARMUP_STEPS; ++i) {
    const model_inputs = generateDummyInputs(1);
    await model_CPU(model_inputs);
    await model_GPU(model_inputs);
  }

  status.textContent = 'Running benchmark...';

  const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x);

  for (const batch_size of batch_sizes) {
    if (interrupted) break;

    const model_inputs = generateDummyInputs(batch_size);

    let wasmTime;
    { // Run WASM
      const start = performance.now();
      await model_CPU(model_inputs);
      const end = performance.now();
      wasmTime = end - start;
    }

    let webGPUTime;
    { // Run WebGPU
      const start = performance.now();
      await model_GPU(model_inputs);
      const end = performance.now();
      webGPUTime = end - start;
    }
    chart.data.labels.push(batch_size);
    chart.data.datasets[0].data.push(wasmTime);
    chart.data.datasets[1].data.push(webGPUTime);
    chart.update();
  }

  if (chart.data.labels.length === 0) return;

  const table = generateResultsTable(chart.data, seqLength);

  // Compute the speedup at the largest completed batch size
  const speedup = chart.data.datasets[0].data.at(-1) / chart.data.datasets[1].data.at(-1);
  const roundedSpeedup = speedup.toFixed(2);
  const params = new URLSearchParams({
    title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`,
    description: table.outerHTML,
  });

  const paramsStr = params.toString();
  status.innerHTML = `⚡ Done! WebGPU is <strong>${roundedSpeedup}x</strong> faster! <a href="https://huggingface.co/spaces/Xenova/webgpu-embedding-benchmark/discussions/new?${paramsStr}" target="_blank">Share results</a>`;
  start.disabled = false;
});
start.disabled = false;

stop.addEventListener('click', () => {
  status.textContent = 'Stopping...';
  interrupted = true;
  stop.disabled = true;
});

function generateResultsTable(data, sequence_length) {
  const datasets = data.datasets.map(d => d.data);
  const batch_sizes = data.labels;

  const container = document.createElement('div');

  const table = document.createElement('table');
  const thead = table.createTHead();
  const tbody = table.createTBody();

  // Add header row
  const headerRow = thead.insertRow();
  headerRow.insertCell().textContent = 'Batch Size';
  headerRow.insertCell().textContent = 'WASM (ms)';
  headerRow.insertCell().textContent = 'WebGPU (ms)';

  // Add data rows
  batch_sizes.forEach((batchSize, rowIndex) => {
    const row = tbody.insertRow();
    row.insertCell().textContent = batchSize;
    datasets.forEach(dataset => {
      row.insertCell().textContent = dataset[rowIndex].toFixed(2);
    });
  });

  container.appendChild(table);

  const createBulletPoint = (text) => {
    const li = document.createElement('li');
    li.textContent = text;
    return li;
  };

  // Add other information
  const info = document.createElement('ul');
  info.appendChild(createBulletPoint(`Model: ${MODEL_ID}`));
  info.appendChild(createBulletPoint(`Quantized: ${QUANTIZED}`));
  info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`));
  info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`));
  info.appendChild(createBulletPoint(`GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`));
  container.appendChild(info);

  return container;
}
```