Skip to content

Commit 33d5ea3

Browse files
authored
[js/webgpu] fixes for fp16 attention (#20440)
1 parent 80213a9 commit 33d5ea3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
264264
let local_offset = local_idx * uniforms.elements_per_thread;
265265
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
266266
267-
var thread_max_vector = ${inputHelper.type.value}(-3.402823e+38f);
267+
var thread_max_vector = ${f32Type}(-3.402823e+38f);
268268
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
269269
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
270270
}
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
282282
})()};
283283
workgroupBarrier();
284284
285-
var max_value: f32 = -3.402823e+38f;
285+
var max_value = -3.402823e+38f;
286286
for (var i = 0u; i < ${WG}; i++) {
287287
max_value = max(thread_max[i], max_value);
288288
}
289289
290-
var sum_vector = ${inputHelper.type.value}(${0});
290+
var sum_vector = ${f32Type}(${0});
291291
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
292292
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
293293
}

js/web/lib/wasm/jsep/webgpu/ops/common.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ export const castToF32 = (dataType: string, components: number, value: string) =
313313
return `f32(${value})`;
314314
}
315315

316-
return `vec${components}f32(${value})`;
316+
return `vec${components}<f32>(${value})`;
317317
};
318318

319319
/**

0 commit comments

Comments
 (0)