Skip to content

Commit 0e45e31

Browse files
authored
Merge pull request #6 from link-assistant/issue-5-9044dbde1c45
fix: resolve browser inference 'unexpected rank' error
2 parents e76bf60 + e762061 commit 0e45e31

File tree

9 files changed

+5164
-10
lines changed

9 files changed

+5164
-10
lines changed

.github/workflows/e2e.yml

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
name: E2E Tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
workflow_dispatch:
11+
12+
concurrency:
13+
group: e2e-${{ github.ref }}
14+
cancel-in-progress: true
15+
16+
env:
17+
CARGO_TERM_COLOR: always
18+
19+
jobs:
20+
e2e-tests:
21+
name: Browser E2E Tests
22+
runs-on: ubuntu-latest
23+
timeout-minutes: 30
24+
steps:
25+
- uses: actions/checkout@v4
26+
27+
- name: Setup Rust
28+
uses: dtolnay/rust-toolchain@stable
29+
with:
30+
targets: wasm32-unknown-unknown
31+
32+
- name: Install wasm-pack
33+
run: cargo install wasm-pack
34+
35+
- name: Setup Node.js
36+
uses: actions/setup-node@v4
37+
with:
38+
node-version: '20.x'
39+
40+
- name: Cache cargo registry
41+
uses: actions/cache@v4
42+
with:
43+
path: |
44+
~/.cargo/registry
45+
~/.cargo/git
46+
target
47+
wasm/target
48+
key: ${{ runner.os }}-cargo-wasm-${{ hashFiles('**/Cargo.lock') }}
49+
restore-keys: |
50+
${{ runner.os }}-cargo-wasm-
51+
52+
- name: Cache npm dependencies
53+
uses: actions/cache@v4
54+
with:
55+
path: web/node_modules
56+
key: ${{ runner.os }}-npm-${{ hashFiles('web/package-lock.json') }}
57+
restore-keys: |
58+
${{ runner.os }}-npm-
59+
60+
- name: Build WASM package
61+
env:
62+
RUSTFLAGS: '--cfg getrandom_backend="wasm_js" -C target-feature=+bulk-memory,+mutable-globals,+simd128'
63+
run: |
64+
cd wasm
65+
wasm-pack build --target web --out-dir ../web/src/pkg
66+
67+
- name: Install npm dependencies
68+
run: |
69+
cd web
70+
npm install
71+
72+
- name: Install Playwright browsers
73+
run: |
74+
cd web
75+
npx playwright install chromium --with-deps
76+
77+
- name: Build web application
78+
run: |
79+
cd web
80+
npm run build
81+
82+
- name: Run E2E tests
83+
run: |
84+
cd web
85+
npm run test:e2e
86+
env:
87+
CI: true
88+
89+
- name: Upload Playwright report
90+
uses: actions/upload-artifact@v4
91+
if: always()
92+
with:
93+
name: playwright-report
94+
path: web/playwright-report/
95+
retention-days: 7
96+
97+
- name: Upload test results
98+
uses: actions/upload-artifact@v4
99+
if: failure()
100+
with:
101+
name: test-results
102+
path: web/test-results/
103+
retention-days: 7

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ web/src/pkg/
6363
.npm
6464
npm-debug.log*
6565

66+
# Playwright
67+
web/playwright-report/
68+
web/test-results/
69+
web/playwright/.cache/
70+
6671
# WASM build artifacts
6772
wasm/pkg/
6873
wasm/target/
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
### Fixed
2+
3+
- Fixed browser inference "Repeat penalty failed: unexpected rank" error that occurred when generating text. The bug was caused by incorrectly attempting to index into the logits tensor after the Llama model's forward pass, which already extracts the last position internally.
4+
5+
### Added
6+
7+
- Added Playwright e2e tests to verify browser inference works correctly
8+
- Added GitHub Actions workflow for running e2e tests on PRs and main branch

wasm/src/lib.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,12 @@ pub async fn generate(
235235
.map_err(|e| JsValue::from_str(&format!("Forward pass failed: {}", e)))?;
236236

237237
// Get logits for next token prediction
238+
// The Llama model already extracts the last position internally,
239+
// so the output shape is [batch_size, vocab_size], not [batch_size, seq_len, vocab_size]
238240
let logits = logits
239241
.squeeze(0)
240242
.map_err(|e| JsValue::from_str(&format!("Squeeze failed: {}", e)))?;
241243

242-
let seq_len = logits
243-
.dim(0)
244-
.map_err(|e| JsValue::from_str(&format!("Failed to get dim: {}", e)))?;
245-
246-
let logits = logits
247-
.get(seq_len - 1)
248-
.map_err(|e| JsValue::from_str(&format!("Get logits failed: {}", e)))?;
249-
250244
// Apply repeat penalty
251245
let logits = if params.repeat_penalty != 1.0 {
252246
let start_at = all_tokens.len().saturating_sub(params.repeat_last_n);

web/e2e/inference.spec.ts

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import { test, expect } from '@playwright/test';
2+
3+
/**
4+
* E2E tests for SmolLM2 browser inference.
5+
*
6+
* These tests verify that the WASM-based language model can:
7+
* 1. Load successfully in the browser
8+
* 2. Generate text responses without errors
9+
* 3. Stream tokens back to the UI
10+
*
11+
* Note: These tests require significant time due to:
12+
* - Model download (~270MB)
13+
* - WASM compilation
14+
* - Inference computation
15+
*/
16+
17+
test.describe('SmolLM2 Browser Inference', () => {
18+
// Run tests serially since they share model state
19+
test.describe.configure({ mode: 'serial' });
20+
21+
test('should display initial UI correctly', async ({ page }) => {
22+
await page.goto('/');
23+
24+
// Check header
25+
await expect(page.getByRole('heading', { name: 'SmolLM2 in Browser' })).toBeVisible();
26+
await expect(
27+
page.getByText('AI language model running entirely on your device via WebAssembly')
28+
).toBeVisible();
29+
30+
// Check load button
31+
await expect(page.getByRole('button', { name: /Load Model/i })).toBeVisible();
32+
await expect(page.getByRole('button', { name: /Load Model/i })).toBeEnabled();
33+
34+
// Check initial message
35+
await expect(
36+
page.getByText(/Hello! I'm SmolLM2, a small language model running entirely in your browser/)
37+
).toBeVisible();
38+
39+
// Check footer info
40+
await expect(page.getByText(/No data sent to servers/)).toBeVisible();
41+
42+
// Check initial status - worker sends 'Worker initialized' on startup
43+
await expect(page.getByText('Worker initialized')).toBeVisible();
44+
});
45+
46+
test('should load model successfully', async ({ page }) => {
47+
await page.goto('/');
48+
49+
// Verify initial state
50+
await expect(page.getByRole('button', { name: /Load Model/i })).toBeVisible();
51+
52+
// Click load button
53+
await page.getByRole('button', { name: /Load Model/i }).click();
54+
55+
// Should show loading status
56+
await expect(page.getByText(/Initializing|Downloading|Loading/i)).toBeVisible({
57+
timeout: 10000,
58+
});
59+
60+
// Wait for model to be ready (this can take several minutes)
61+
await expect(page.getByText('Model ready')).toBeVisible({
62+
timeout: 5 * 60 * 1000, // 5 minutes
63+
});
64+
65+
// Load button should be gone
66+
await expect(page.getByRole('button', { name: /Load Model/i })).not.toBeVisible();
67+
68+
// Message input should be enabled
69+
await expect(page.locator('.cs-message-input__content-editor')).toBeEnabled();
70+
});
71+
72+
test('should generate text response without errors', async ({ page }) => {
73+
await page.goto('/');
74+
75+
// Listen for console errors from the start
76+
const consoleErrors: string[] = [];
77+
page.on('console', (msg) => {
78+
if (msg.type() === 'error') {
79+
consoleErrors.push(msg.text());
80+
}
81+
});
82+
83+
// Load the model first
84+
await page.getByRole('button', { name: /Load Model/i }).click();
85+
await expect(page.getByText('Model ready')).toBeVisible({
86+
timeout: 5 * 60 * 1000,
87+
});
88+
89+
// Send a message
90+
const messageInput = page.locator('.cs-message-input__content-editor');
91+
await messageInput.fill('Hello');
92+
await messageInput.press('Enter');
93+
94+
// Should show user message
95+
await expect(page.getByText('Hello').first()).toBeVisible();
96+
97+
// Wait for generation to complete (typing indicator should appear then disappear)
98+
// The response should appear within 2 minutes
99+
await expect(page.getByText('SmolLM2 is thinking...')).toBeVisible({ timeout: 10000 });
100+
101+
// Wait for typing indicator to disappear (generation complete)
102+
await expect(page.getByText('SmolLM2 is thinking...')).not.toBeVisible({
103+
timeout: 2 * 60 * 1000,
104+
});
105+
106+
// Check for the critical error that was reported in issue #5
107+
const repeatPenaltyError = consoleErrors.find((e) =>
108+
e.includes('Repeat penalty failed: unexpected rank')
109+
);
110+
expect(repeatPenaltyError).toBeUndefined();
111+
112+
// There should be no error status
113+
await expect(page.getByText(/Error:/i)).not.toBeVisible();
114+
115+
// Status should still be "Model ready" (not error state)
116+
await expect(page.getByText('Model ready')).toBeVisible();
117+
});
118+
119+
test('should stream tokens to the UI', async ({ page }) => {
120+
await page.goto('/');
121+
122+
// Load the model first
123+
await page.getByRole('button', { name: /Load Model/i }).click();
124+
await expect(page.getByText('Model ready')).toBeVisible({
125+
timeout: 5 * 60 * 1000,
126+
});
127+
128+
// Send a message
129+
const messageInput = page.locator('.cs-message-input__content-editor');
130+
await messageInput.fill('Count from 1 to 5');
131+
await messageInput.press('Enter');
132+
133+
// Wait for generation to complete
134+
await expect(page.getByText('SmolLM2 is thinking...')).toBeVisible({ timeout: 10000 });
135+
await expect(page.getByText('SmolLM2 is thinking...')).not.toBeVisible({
136+
timeout: 2 * 60 * 1000,
137+
});
138+
139+
// There should be multiple AI response regions (initial greeting + new response)
140+
const aiMessages = page.locator('[class*="cs-message--incoming"]');
141+
await expect(aiMessages).toHaveCount(2, { timeout: 5000 });
142+
});
143+
});
144+
145+
test.describe('Error Handling', () => {
146+
test('should handle model loading gracefully', async ({ page }) => {
147+
await page.goto('/');
148+
149+
// Click load button
150+
await page.getByRole('button', { name: /Load Model/i }).click();
151+
152+
// Should not crash immediately
153+
await expect(page.getByText(/Initializing|Downloading/i)).toBeVisible({
154+
timeout: 10000,
155+
});
156+
157+
// Page should remain responsive
158+
await expect(page.getByRole('heading', { name: 'SmolLM2 in Browser' })).toBeVisible();
159+
});
160+
});

0 commit comments

Comments
 (0)