Skip to content

Commit 4757cf9

Browse files
committed
refactor: use a union type for traced types
1 parent 45158bb commit 4757cf9

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function compile_mlir!(mod, f, args; optimize=true)
290290
preserved_args = Tuple{TracedRArray,Int}[]
291291
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
292292
nresults = MLIR.IR.Value[]
293-
linear_results2 = Union{TracedRArray,TracedRNumber}[]
293+
linear_results2 = TracedTypes[]
294294
for (i, op) in enumerate(results)
295295
if !MLIR.IR.is_block_arg(op)
296296
push!(nresults, op)

src/Reactant.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,13 @@ include("mlir/MLIR.jl")
8383
include("XLA.jl")
8484
include("Interpreter.jl")
8585
include("utils.jl")
86+
8687
include("ConcreteRArray.jl")
8788
include("TracedRNumber.jl")
8889
include("TracedRArray.jl")
90+
91+
const TracedTypes = Union{TracedRArray,TracedRNumber}
92+
8993
include("Tracing.jl")
9094
include("Compiler.jl")
9195

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
4444
)
4545
end
4646

47-
linear_args = Union{TracedRArray,TracedRNumber}[]
47+
linear_args = TracedTypes[]
4848
for (k, v) in seen_args
4949
if !(v isa TracedRArray) && !(v isa TracedRNumber)
5050
continue
@@ -127,7 +127,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
127127
)
128128
end
129129

130-
linear_results = Union{TracedRArray,TracedRNumber}[]
130+
linear_results = TracedTypes[]
131131

132132
for (k, v) in seen_results
133133
if !(v isa TracedRArray) && !(v isa TracedRNumber)

0 commit comments

Comments
 (0)