Skip to content

Commit d908837

Browse files
committed
Format code
1 parent eff40f1 commit d908837

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

src/Compiler.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,9 @@ function compile(f, args; pipeline_options="", client=nothing)
410410

411411
res = :($sym.buffer)
412412
push!(linearized_args, res)
413-
413+
414414
respaths = ((p for p in arg.paths if p[1] != :args)...,)
415-
415+
416416
resarg = false
417417
for respath in respaths
418418
if respath[1] == :result
@@ -440,7 +440,7 @@ function compile(f, args; pipeline_options="", client=nothing)
440440
push!(resarg_syncs, usbuf)
441441
end
442442
end
443-
443+
444444
for (idx, result) in enumerate(linear_results)
445445
paths = ((p for p in result.paths if p[1] != :args)...,)
446446
for path in paths

src/TracedRArray.jl

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1616
end
1717
end
1818

19-
function Base.getindex(a::TracedRArray{T,0}) where T
19+
function Base.getindex(a::TracedRArray{T,0}) where {T}
2020
return a
2121
end
2222

@@ -75,13 +75,13 @@ end
7575
function Base.setindex!(
7676
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
7777
) where {T,N}
78-
indices = [(promote_to(TracedRArray{Int, 0}, i isa Colon ? 1 : first(i))-1).mlir_data for i in indices]
78+
indices = [
79+
(promote_to(TracedRArray{Int,0}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
80+
i in indices
81+
]
7982
v = promote_to(TracedRArray{T,N}, v)
8083
res = MLIR.IR.result(
81-
MLIR.Dialects.stablehlo.dynamic_update_slice(
82-
a.mlir_data, v.mlir_data, indices
83-
),
84-
1,
84+
MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1
8585
)
8686
a.mlir_data = res
8787
return v
@@ -162,7 +162,7 @@ end
162162

163163
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
164164
if isa(rhs, TracedRArray)
165-
if typeof(rhs) == TracedRArray{T, N}
165+
if typeof(rhs) == TracedRArray{T,N}
166166
return rhs
167167
end
168168
return TracedRArray{T,N}(
@@ -185,9 +185,12 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
185185
end
186186
T0 = eltype(rhs)
187187
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
188-
return promote_to(TracedRArray{T, N}, TracedRArray{T0,length(size(rhs))}(
189-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
190-
))
188+
return promote_to(
189+
TracedRArray{T,N},
190+
TracedRArray{T0,length(size(rhs))}(
191+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
192+
),
193+
)
191194
end
192195

193196
function promote_to(lhs::TracedRArray{T,N}, rhs) where {T,N}
@@ -491,31 +494,32 @@ for (jlop, hloop, hlocomp, merge) in (
491494
function elem_apply(
492495
fn::typeof($jlop), @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
493496
) where {T,N}
494-
elem_apply(fn, lhs, promote_to(lhs, rhs))
497+
return elem_apply(fn, lhs, promote_to(lhs, rhs))
495498
end
496499

497500
function elem_apply(
498501
::typeof($jlop), @nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
499502
) where {T,N}
500-
elem_apply(fn, promote_to(rhs, lhs), rhs)
503+
return elem_apply(fn, promote_to(rhs, lhs), rhs)
501504
end
502505

503-
function $jlop(@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
504-
) where {T, N}
505-
$jlop(lhs, promote_to(lhs, rhs))
506+
function $jlop(
507+
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
508+
) where {T,N}
509+
return $jlop(lhs, promote_to(lhs, rhs))
506510
end
507511

508-
function $jlop(@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
509-
) where {T, N}
510-
$jlop(promote_to(rhs, lhs), rhs)
512+
function $jlop(
513+
@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
514+
) where {T,N}
515+
return $jlop(promote_to(rhs, lhs), rhs)
511516
end
512517
end
513-
518+
514519
if merge != nothing
515520
@eval begin
516521
function $jlop(
517-
@nospecialize(lhs::TracedRArray{T,N}),
518-
@nospecialize(rhs::TracedRArray{T,N})
522+
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
519523
) where {T,N}
520524
elems = elem_apply($jlop, lhs, rhs)
521525
if N == 0
@@ -528,10 +532,9 @@ for (jlop, hloop, hlocomp, merge) in (
528532
else
529533
@eval begin
530534
function $jlop(
531-
@nospecialize(lhs::TracedRArray{T,0}),
532-
@nospecialize(rhs::TracedRArray{T,0})
535+
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
533536
) where {T}
534-
elem_apply($jlop, lhs, rhs)
537+
return elem_apply($jlop, lhs, rhs)
535538
end
536539
end
537540
end

src/Tracing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function}
4141
changed = false
4242
traced_fieldtypes = ntuple(Val(N)) do i
4343
next = traced_type(fieldtype(T, i), seen, mode)
44-
changed |= next != fieldtype(T, i)
44+
changed |= next != fieldtype(T, i)
4545
next
4646
end
4747

test/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,15 @@ end
246246
# get_view_compiled = Reactant.compile(get_view, (x_concrete,))
247247
end
248248

249-
tuple_byref(x) = (; a =(; b=x))
249+
tuple_byref(x) = (; a=(; b=x))
250250
tuple_byref2(x) = abs2.(x), tuple_byref2(x)
251251

252252
@testset "Tuple byref" begin
253253
x = Reactant.to_rarray([1.0 -2.0; -3.0 4.0])
254254
f1 = Reactant.compile(tuple_byref, (x,))
255255
r1 = f1(x)
256256
@test r1.a.b.data === x.data
257-
257+
258258
# TODO this seems to hang during compile
259259
# f2 = Reactant.compile(tuple_byref2, (x,))
260260
# r2 = f2(x)

0 commit comments

Comments
 (0)