@@ -78,6 +78,36 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type:
7878 }
7979} ;
8080
/**
 * Builds the shader-source fragment that folds one index component into `data_offset`.
 *
 * The emitted code does not declare `index`, `data_offset` or `i` (nor `indices_start`
 * in the parallel form); the surrounding shader code must provide them, with `index`
 * and `data_offset` being mutable. The fragment:
 *   1. reads the stride (`element_count_dim`) and bound (`dim_value`) for the current component,
 *   2. normalizes `index`: values >= dim_value become dim_value - 1, values < -dim_value
 *      become 0, and the remaining negative values are shifted up by dim_value,
 *   3. adds `index * element_count_dim` to `data_offset`.
 *
 * @param dataRank Rank of the data tensor. For rank 1 the stride/shape uniforms are read
 *   without indexing.
 * @param parallel When true the loop variable `i` is taken relative to `indices_start`
 *   (`i - indices_start`); when false `i` is used directly. Has no effect when `dataRank` is 1.
 * @returns The shader-source fragment as a string.
 */
const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => {
  // Expression that yields the position of the current index component.
  const dimPosition = parallel ? 'i - indices_start' : 'i';

  // NOTE(review): the bound is read at [position + last_index_dimension] while the stride is
  // read at [position]; kept exactly as in the original - confirm the offset is intended.
  const dimLookup =
    dataRank === 1
      ? `
    let element_count_dim = uniforms.output_strides;
    let dim_value = uniforms.output_shape;`
      : `
    let element_count_dim = uniforms.output_strides[${dimPosition}];
    let dim_value = uniforms.output_shape[${dimPosition} + uniforms.last_index_dimension];`;

  return `${dimLookup}

    if (index >= 0) {
      if (index >= i32(dim_value)) {
        index = i32(dim_value - 1);
      }
    } else {
      if (index < -i32(dim_value)) {
        index = 0;
      } else {
        index += i32(dim_value);
      }
    }
    data_offset += u32((u32(index) * element_count_dim));`;
};
104+
/**
 * Builds the shader-source loop that applies one slice of `updates` to `output`.
 *
 * The emitted loop runs `i` over `uniforms.num_updates_elements`, reads
 * `updates[uniforms.num_updates_elements * <slice> + i]` into `value`, and combines it
 * into `output[data_offset + i]` using the statement produced by `atomicReductionSnippet`.
 * `data_offset` (and `global_idx` / `idx`) must be declared by the surrounding shader code.
 *
 * @param attributes ScatterND attributes; only `reduction` is read here.
 * @param outputTypeValue Element type forwarded to `atomicReductionSnippet`.
 * @param parallel Selects the slice variable: `global_idx` when true, `idx` when false.
 * @returns The shader-source fragment as a string.
 */
const updateElementsSnippet = (
  attributes: ScatterNDAttributes,
  outputTypeValue: ReductionType,
  parallel: boolean,
) => {
  // Name of the shader variable that identifies which update slice is being applied.
  const sliceVariable = parallel ? 'global_idx' : 'idx';
  const reductionStatement = atomicReductionSnippet(
    attributes.reduction,
    'output[data_offset + i]',
    'value',
    outputTypeValue,
  );

  return `for (var i = 0u; i < uniforms.num_updates_elements; i++) {
    let value = updates[uniforms.num_updates_elements * ${sliceVariable} + i];
    ${reductionStatement}
  }`;
};
110+
81111const createScatterNDProgramInfo = ( inputs : readonly TensorView [ ] , attributes : ScatterNDAttributes ) : ProgramInfo => {
82112 const inputShape = inputs [ 0 ] . dims ;
83113 const indicesShape = inputs [ 1 ] . dims ;
@@ -87,6 +117,7 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
87117 const outputSize = Math . ceil ( ShapeUtil . size ( indicesShape ) / components ) ;
88118 const lastIndexDimension = indicesShape [ indicesShape . length - 1 ] ;
89119 const numUpdatesElements = ShapeUtil . sizeFromDimension ( inputShape , lastIndexDimension ) ;
120+ const numIndicesElements = ShapeUtil . sizeFromDimension ( indicesShape , 0 ) / lastIndexDimension ;
90121
91122 const programUniforms : ProgramUniform [ ] = [
92123 { type : DataType . uint32 , data : outputSize } ,
@@ -113,9 +144,8 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
113144 ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
114145 var hasDuplicates = false;
115146 if (${ attributes . reduction === 'none' } ) {
116- let n = ${ ShapeUtil . size ( indicesShape ) } ;
117- for (var i = 0; i < n; i = i + 1) {
118- for (var j = i + 1; j < n; j = j + 1) {
147+ for (var i = 0; i < ${ numIndicesElements } ; i = i + 1) {
148+ for (var j = i + 1; j < ${ numIndicesElements } ; j = j + 1) {
119149 var index_i = i32(indices[i].x);
120150 var index_j = i32(indices[j].x);
121151 if (index_i == index_j) {
@@ -129,51 +159,31 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S
129159 }
130160 }
131161
132- var data_offset = 0u;
133- var indices_start = uniforms.last_index_dimension * global_idx;
134162 if (${ attributes . reduction === 'none' } && hasDuplicates) {
135163 if (global_idx != 0u) {
136164 return;
137165 }
138- indices_start = 0u;
139- }
140- let indices_end = indices_start + uniforms.last_index_dimension;
141- for (var i = indices_start; i < indices_end; i++) {
142- var index = i32(indices[i].x);
143- ${
144- inputs [ 0 ] . dims . length === 1
145- ? `
146- let element_count_dim = uniforms.output_strides;
147- let dim_value = uniforms.output_shape;`
148- : `
149- let element_count_dim = uniforms.output_strides[i - indices_start];
150- let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`
151- }
152- if (index >= 0) {
153- if (index >= i32(dim_value)) {
154- index = i32(dim_value - 1);
155- }
156- } else {
157- if (index < -i32(dim_value)) {
158- index = 0;
159- } else {
160- index += i32(dim_value);
166+ // Process each index-update pair individually when duplicates exist
167+ for (var idx = 0u; idx < ${ numIndicesElements } u; idx++) {
168+ var data_offset = 0u;
169+ for (var i = 0u; i < uniforms.last_index_dimension; i++) {
170+ var index = i32(indices[idx * uniforms.last_index_dimension + i].x);
171+ ${ calcDataOffsetSnippet ( inputShape . length , false ) }
161172 }
173+ ${ updateElementsSnippet ( attributes , output . type . value as ReductionType , false ) }
162174 }
163- data_offset += u32((u32(index) * element_count_dim)) ;
175+ return ;
164176 }
165177
166- for (var i = 0u; i < uniforms.num_updates_elements; i++) {
167- let value = updates[uniforms.num_updates_elements * global_idx + i];
168- ${ atomicReductionSnippet (
169- attributes . reduction ,
170- 'output[data_offset + i]' ,
171- 'value' ,
172- output . type . value as ReductionType ,
173- ) }
178+ var data_offset = 0u;
179+ var indices_start = uniforms.last_index_dimension * global_idx;
180+ var indices_end = indices_start + uniforms.last_index_dimension;
181+ for (var i = indices_start; i < indices_end; i++) {
182+ var index = i32(indices[i].x);
183+ ${ calcDataOffsetSnippet ( inputShape . length , true ) }
174184 }
175-
176- }` ;
185+ ${ updateElementsSnippet ( attributes , output . type . value as ReductionType , true ) }
186+ }` ;
177187 } ;
178188 return {
179189 name : 'ScatterND' ,
0 commit comments