@@ -66,6 +66,10 @@ and require expensive copies and synchronization each time and therefore should
66
66
return TracedRNumber {T} ((), res2)
67
67
end
68
68
69
+ function Base. getindex (a:: TracedRArray{T,0} ) where {T}
70
+ return TracedRNumber {T} ((), a. mlir_data)
71
+ end
72
+
69
73
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
70
74
indices = [i isa Colon ? (1 : size (a, idx)) : i for (idx, i) in enumerate (indices)]
71
75
res = MLIR. IR. result (
@@ -222,7 +226,12 @@ function elem_apply(
222
226
end
223
227
224
228
function elem_apply (f, args:: Vararg{Any,Nargs} ) where {Nargs}
225
- all (iszero ∘ ndims, args) && return f (args... )
229
+ if all (iszero ∘ ndims, args)
230
+ scalar_args = map (args) do arg
231
+ return promote_to (TracedRNumber{eltype (arg)}, arg)
232
+ end
233
+ return f (scalar_args... )
234
+ end
226
235
227
236
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn (
228
237
f, args, (), string (f) * " _broadcast_scalar" , false ; toscalar= true
@@ -440,6 +449,12 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N}
440
449
return A
441
450
end
442
451
452
+ function Base. fill! (A:: TracedRArray{T,N} , x:: TracedRNumber{T2} ) where {T,N,T2}
453
+ bcast = broadcast_to_size (promote_to (TracedRNumber{T}, x), size (A))
454
+ A. mlir_data = bcast. mlir_data
455
+ return A
456
+ end
457
+
443
458
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
444
459
445
460
AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
458
473
459
474
function Base. similar (
460
475
bc:: Broadcasted{AbstractReactantArrayStyle{N}} , :: Type{T} , dims
461
- ) where {T,N}
476
+ ) where {T<: ReactantPrimitives ,N}
477
+ @assert N isa Int
478
+ return TracedRArray {T,N} ((), nothing , map (length, dims))
479
+ end
480
+
481
+ function Base. similar (
482
+ bc:: Broadcasted{AbstractReactantArrayStyle{N}} , :: Type{<:TracedRNumber{T}} , dims
483
+ ) where {T<: ReactantPrimitives ,N}
462
484
@assert N isa Int
463
485
return TracedRArray {T,N} ((), nothing , map (length, dims))
464
486
end
@@ -536,7 +558,7 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
536
558
end
537
559
538
560
function broadcast_to_size (arg:: TracedRNumber , rsize)
539
- rsize == () && return arg
561
+ length ( rsize) == 0 && return arg
540
562
mlirty = MLIR. IR. type (arg. mlir_data)
541
563
return TracedRArray {eltype(arg),length(rsize)} (
542
564
(),
0 commit comments