From 3b91d3a686653b3a71a534c7b53b2c9246ce3a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 27 Sep 2025 23:45:44 +0100 Subject: [PATCH] Allow passing custom size to Array-like `make_tracer` methods This allows creating traced arrays with shape different than the original data, without having to add extra `reshape` instructions. --- src/Tracing.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 04fdc97255..6d79c517fb 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1169,6 +1169,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(device = nothing), @nospecialize(client = nothing), + @nospecialize(size = size(prev)), kwargs..., ) where {T,N} if mode == TracedToTypes @@ -1177,7 +1178,7 @@ Base.@nospecializeinfer function make_tracer( mode == ArrayToConcrete && return ConcretePJRTArray(prev; sharding, device, client) mode != ConcreteToTraced && throw("Cannot trace concrete") haskey(seen, prev) && return seen[prev]::TracedRArray{T,N} - res = TracedRArray{T,N}((path,), nothing, size(prev)) + res = TracedRArray{T,N}((path,), nothing, size) seen[prev] = res return res end @@ -1190,6 +1191,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(device = nothing), @nospecialize(client = nothing), + @nospecialize(size = size(prev)), kwargs..., ) where {T,N} if mode == TracedToTypes @@ -1198,7 +1200,7 @@ Base.@nospecializeinfer function make_tracer( mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client) mode != ConcreteToTraced && throw("Cannot trace concrete") haskey(seen, prev) && return seen[prev]::TracedRArray{T,N} - res = TracedRArray{T,N}((path,), nothing, size(prev)) + res = TracedRArray{T,N}((path,), nothing, size) seen[prev] = res return res end @@ -1254,6 +1256,7 @@ Base.@nospecializeinfer function make_tracer( tobatch=nothing, @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(size = size(prev)), kwargs..., ) where {T,N} if mode == ConcreteToTraced @@ -1287,7 +1290,7 @@ Base.@nospecializeinfer function make_tracer( elseif tobatch !== nothing error("This should not happen...") else - TracedRArray{T,N}((path,), prev.mlir_data, size(prev)) + TracedRArray{T,N}((path,), prev.mlir_data, size) end seen[prev] = res return res @@ -1298,7 +1301,7 @@ Base.@nospecializeinfer function make_tracer( haskey(seen, prev) && return seen[prev]::ConcretePJRTArray{T,N} if !Sharding.is_sharded(sharding) res = ConcretePJRTArray{T,N,1,Sharding.NoShardInfo}( - (XLA.PJRT.AsyncEmptyBuffer,), size(prev), Sharding.NoShardInfo() + (XLA.PJRT.AsyncEmptyBuffer,), size, Sharding.NoShardInfo() ) else error("TODO: implement sharding") @@ -1309,7 +1312,7 @@ Base.@nospecializeinfer function make_tracer( haskey(seen, prev) && return seen[prev]::ConcreteIFRTArray{T,N} if !Sharding.is_sharded(sharding) res = ConcreteIFRTArray{T,N,Sharding.NoShardInfo}( - XLA.IFRT.AsyncEmptyArray, size(prev), Sharding.NoShardInfo() + XLA.IFRT.AsyncEmptyArray, size, Sharding.NoShardInfo() ) else error("TODO: implement sharding") @@ -1556,6 +1559,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(runtime = nothing), @nospecialize(device = nothing), @nospecialize(client = nothing), + @nospecialize(size = size(prev)), kwargs..., ) RT = Core.Typeof(prev) @@ -1604,7 +1608,7 @@ Base.@nospecializeinfer function make_tracer( return nothing end TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) - newa = Array{TT,ndims(RT)}(undef, size(prev)) + newa = Array{TT,ndims(RT)}(undef, size) seen[prev] = newa same = true for I in eachindex(prev)