@@ -140,20 +140,28 @@ function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {
140
140
return _getindex_cartesian (a, indices)
141
141
end
142
142
143
- _isone (x) = isone (x)
144
- _isone (:: CartesianIndex ) = false
145
-
146
- __contiguous_indices (:: Base.LogicalIndex ) = false
147
- __contiguous_indices (x) = all (_isone, diff (x))
148
-
149
143
function _getindex_linear (a:: TracedRArray{T,N} , indices:: AbstractArray ) where {T,N}
150
- if ! (indices isa Reactant. TracedType) && __contiguous_indices (vec (indices))
151
- a_flat = materialize_traced_array (vec (a))
152
- indices_flat = vec (indices)
153
- return Ops. reshape (
154
- Ops. dynamic_slice (a_flat, [first (indices_flat)], [length (indices_flat)]),
155
- collect (size (indices)),
156
- )
144
+ if ! (indices isa Reactant. TracedType)
145
+ if length (indices) == 1 && first (indices) isa CartesianIndex
146
+ # fast path: index the single CartesianIndex directly; otherwise we would end up with a gather
147
+ return TracedUtils. broadcast_to_size (
148
+ @allowscalar (_getindex_cartesian (a, first (indices))), (1 ,)
149
+ )
150
+ end
151
+ stride = TracedUtils. _get_slice_stride (vec (indices))
152
+ if stride > 0
153
+ a_flat = materialize_traced_array (vec (a))
154
+ indices_flat = vec (indices)
155
+ return Ops. reshape (
156
+ Ops. slice (
157
+ a_flat,
158
+ Int64[first (indices_flat)],
159
+ Int64[last (indices_flat)];
160
+ strides= Int64[stride],
161
+ ),
162
+ collect (Int64, size (indices)),
163
+ )
164
+ end
157
165
end
158
166
159
167
if ! (indices isa TracedRArray)
@@ -180,14 +188,21 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
180
188
indices = Base. to_indices (a, indices)
181
189
182
190
use_gather_getindex = false
191
+ use_dynamic_slice = false
192
+ strides = Int64[]
183
193
for idxs in indices
184
- idxs isa Number && continue
194
+ if idxs isa Number
195
+ idxs isa TracedRNumber && (use_dynamic_slice = true )
196
+ push! (strides, 1 )
197
+ continue
198
+ end
185
199
if idxs isa Reactant. TracedType
186
200
use_gather_getindex = true
187
201
break
188
202
end
189
- contiguous = __contiguous_indices (vec (idxs))
190
- if typeof (contiguous) <: Bool && ! contiguous
203
+ stride = TracedUtils. _get_slice_stride (vec (idxs))
204
+ push! (strides, stride)
205
+ if stride ≤ 0 || (use_dynamic_slice && stride != 1 )
191
206
use_gather_getindex = true
192
207
break
193
208
end
@@ -200,18 +215,37 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
200
215
error (" Boolean indexing with TracedRArrays isn't fully supported yet." )
201
216
end
202
217
203
- indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils. traced_indices (
204
- indices...
205
- )
206
- res = Ops. reshape (
207
- Ops. gather_getindex (a, generate_index_list (indices... )), preddim_result_size
218
+ gather_dims = TracedUtils. indices_to_gather_dims (indices... )
219
+
220
+ return Ops. reshape (
221
+ Ops. transpose (
222
+ Ops. reshape (
223
+ Ops. gather (
224
+ a,
225
+ gather_dims. start_indices;
226
+ gather_dims. offset_dims,
227
+ gather_dims. collapsed_slice_dims,
228
+ operand_batching_dims= Int64[],
229
+ start_indices_batching_dims= Int64[],
230
+ gather_dims. start_index_map,
231
+ gather_dims. index_vector_dim,
232
+ gather_dims. slice_sizes,
233
+ ),
234
+ gather_dims. gather_reshape_shape,
235
+ ),
236
+ gather_dims. permutation,
237
+ ),
238
+ gather_dims. result_shape,
208
239
)
209
- isempty (integer_indices) ||
210
- (res = materialize_traced_array (dropdims (res; dims= integer_indices)))
211
- return Ops. reshape (res, result_size)
212
240
end
213
241
214
- x = Ops. dynamic_slice (a, [first .(indices)... ], [length .(indices)... ])
242
+ if use_dynamic_slice
243
+ @assert all (isone, strides) " This should not happen, please report a bug"
244
+ x = Ops. dynamic_slice (a, [first .(indices)... ], [length .(indices)... ])
245
+ else
246
+ x = Ops. slice (a, [first .(indices)... ], [last .(indices)... ]; strides)
247
+ end
248
+
215
249
ddims = findall (indices) do idx
216
250
return idx isa Integer || idx isa TracedRNumber{<: Integer }
217
251
end
@@ -313,7 +347,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, index::CartesianIndex{N}) where
313
347
end
314
348
315
349
function _setindex_linear! (a:: TracedRArray{T,N} , v, indices:: AbstractArray ) where {T,N}
316
- if ! (indices isa Reactant. TracedType) && __contiguous_indices (vec (indices))
350
+ if ! (indices isa Reactant. TracedType) && TracedUtils . __contiguous_indices (vec (indices))
317
351
res = Ops. reshape (
318
352
Ops. dynamic_update_slice (
319
353
materialize_traced_array (vec (a)),
@@ -371,7 +405,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
371
405
use_scatter_setindex = true
372
406
break
373
407
end
374
- contiguous = __contiguous_indices (idxs)
408
+ contiguous = TracedUtils . __contiguous_indices (idxs)
375
409
if typeof (contiguous) <: Bool && ! contiguous
376
410
use_scatter_setindex = true
377
411
break
@@ -384,9 +418,44 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
384
418
if any (i -> unwrapped_eltype (i) <: Bool , indices)
385
419
error (" Boolean indexing with TracedRArrays isn't fully supported yet." )
386
420
end
387
- indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), indices)
388
- indices_list = generate_index_list (indices_list... )
389
- res = Ops. scatter_setindex (a, indices_list, Ops. reshape (v, length (v)))
421
+
422
+ gather_dims = TracedUtils. indices_to_gather_dims (indices... )
423
+
424
+ v = Ops. convert (
425
+ TracedRArray{T,ndims (v)},
426
+ TracedUtils. promote_to (TracedRArray{unwrapped_eltype (v),ndims (v)}, v),
427
+ )
428
+
429
+ updates = Ops. transpose (v, invperm (gather_dims. permutation))
430
+ n_collapsed = length (gather_dims. collapsed_slice_dims)
431
+ updates_shape = Int64[
432
+ prod (size (updates)[1 : n_collapsed]), size (updates)[(n_collapsed + 1 ): end ]. ..
433
+ ]
434
+ updates = Ops. reshape (updates, updates_shape)
435
+
436
+ # simply set the 2nd block argument as a result
437
+ update_computation = MLIR. IR. Region ()
438
+ block = MLIR. IR. Block (
439
+ [Ops. mlir_type (TracedRNumber{T}), Ops. mlir_type (TracedRNumber{T})],
440
+ [MLIR. IR. Location (), MLIR. IR. Location ()],
441
+ )
442
+ return_op = MLIR. Dialects. stablehlo. return_ ([MLIR. IR. argument (block, 2 )])
443
+ MLIR. IR. rmfromparent! (return_op)
444
+ push! (block, return_op)
445
+ pushfirst! (update_computation, block)
446
+
447
+ res = Ops. scatter (
448
+ [a],
449
+ gather_dims. start_indices,
450
+ [updates];
451
+ update_computation,
452
+ update_window_dims= gather_dims. offset_dims,
453
+ inserted_window_dims= gather_dims. collapsed_slice_dims,
454
+ input_batching_dims= Int64[],
455
+ scatter_indices_batching_dims= Int64[],
456
+ scatter_dims_to_operand_dims= gather_dims. start_index_map,
457
+ index_vector_dim= gather_dims. index_vector_dim,
458
+ )[1 ]
390
459
set_mlir_data! (a, get_mlir_data (res))
391
460
return v
392
461
end
0 commit comments