Skip to content

Commit 5731c0b

Browse files
authored
use row major when building attributes (EnzymeAD#307)
* use row major when building attributes * format * opt for d=1 * reproducable test * workaround for 0 dim array * transpose padding * 2,N' -> N,2 * make_causal_mask
1 parent 954b33c commit 5731c0b

File tree

3 files changed

+71
-25
lines changed

3 files changed

+71
-25
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function NNlib.conv!(
109109
#! format: on
110110

111111
padding = Reactant.MLIR.IR.DenseElementsAttribute(
112-
reshape(collect(padding), (num_spatial_dims, 2))
112+
reshape(collect(padding), (2, num_spatial_dims))'
113113
)
114114
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))
115115

@@ -163,7 +163,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
163163
end
164164

165165
padding = Reactant.MLIR.IR.DenseElementsAttribute(
166-
reshape([padding..., 0, 0, 0, 0], (N, 2))
166+
reshape([padding..., 0, 0, 0, 0], (2, N))'
167167
)
168168

169169
output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
@@ -306,7 +306,7 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
306306
len = size(x, dims)
307307
# directly generating booleans were causing an incorrect constant attribute generation
308308
# but the optimized IR removes the type case so we are probably ok
309-
mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))'))
309+
mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))))
310310
return Reactant.promote_to(
311311
TracedRArray{Bool,2},
312312
TracedRArray{Int,2}(

src/mlir/IR/Attribute.jl

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,10 @@ function Base.fill(::Core.Type{Attribute}, value, shape)
492492
return Base.fill(value, shaped_type)
493493
end
494494

495+
to_row_major(x) = permutedims(x, ndims(x):-1:1)
496+
to_row_major(x::AbstractVector) = x
497+
to_row_major(x::AbstractArray{T,0}) where {T} = x
498+
495499
"""
496500
DenseElementsAttribute(array::AbstractArray)
497501
@@ -501,66 +505,86 @@ function DenseElementsAttribute(values::AbstractArray{Bool})
501505
shaped_type = TensorType(size(values), Type(Bool))
502506
return Attribute(
503507
API.mlirDenseElementsAttrBoolGet(
504-
shaped_type, length(values), AbstractArray{Cint}(values)
508+
shaped_type, length(values), AbstractArray{Cint}(to_row_major(values))
505509
),
506510
)
507511
end
508512

509513
function DenseElementsAttribute(values::AbstractArray{UInt8})
510514
shaped_type = TensorType(size(values), Type(UInt8))
511-
return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), values))
515+
return Attribute(
516+
API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), to_row_major(values))
517+
)
512518
end
513519

514520
function DenseElementsAttribute(values::AbstractArray{Int8})
515521
shaped_type = TensorType(size(values), Type(Int8))
516-
return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), values))
522+
return Attribute(
523+
API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), to_row_major(values))
524+
)
517525
end
518526

519527
function DenseElementsAttribute(values::AbstractArray{UInt16})
520528
shaped_type = TensorType(size(values), Type(UInt16))
521529
return Attribute(
522-
API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), values)
530+
API.mlirDenseElementsAttrUInt16Get(
531+
shaped_type, length(values), to_row_major(values)
532+
),
523533
)
524534
end
525535

526536
function DenseElementsAttribute(values::AbstractArray{Int16})
527537
shaped_type = TensorType(size(values), Type(Int16))
528-
return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), values))
538+
return Attribute(
539+
API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), to_row_major(values))
540+
)
529541
end
530542

531543
function DenseElementsAttribute(values::AbstractArray{UInt32})
532544
shaped_type = TensorType(size(values), Type(UInt32))
533545
return Attribute(
534-
API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), values)
546+
API.mlirDenseElementsAttrUInt32Get(
547+
shaped_type, length(values), to_row_major(values)
548+
),
535549
)
536550
end
537551

538552
function DenseElementsAttribute(values::AbstractArray{Int32})
539553
shaped_type = TensorType(size(values), Type(Int32))
540-
return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), values))
554+
return Attribute(
555+
API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), to_row_major(values))
556+
)
541557
end
542558

543559
function DenseElementsAttribute(values::AbstractArray{UInt64})
544560
shaped_type = TensorType(size(values), Type(UInt64))
545561
return Attribute(
546-
API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), values)
562+
API.mlirDenseElementsAttrUInt64Get(
563+
shaped_type, length(values), to_row_major(values)
564+
),
547565
)
548566
end
549567

550568
function DenseElementsAttribute(values::AbstractArray{Int64})
551569
shaped_type = TensorType(size(values), Type(Int64))
552-
return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), values))
570+
return Attribute(
571+
API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), to_row_major(values))
572+
)
553573
end
554574

555575
function DenseElementsAttribute(values::AbstractArray{Float32})
556576
shaped_type = TensorType(size(values), Type(Float32))
557-
return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), values))
577+
return Attribute(
578+
API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), to_row_major(values))
579+
)
558580
end
559581

560582
function DenseElementsAttribute(values::AbstractArray{Float64})
561583
shaped_type = TensorType(size(values), Type(Float64))
562584
return Attribute(
563-
API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), values)
585+
API.mlirDenseElementsAttrDoubleGet(
586+
shaped_type, length(values), to_row_major(values)
587+
),
564588
)
565589
end
566590

@@ -569,16 +593,17 @@ end
569593
function DenseElementsAttribute(values::AbstractArray{Float16})
570594
shaped_type = TensorType(size(values), Type(Float16))
571595
return Attribute(
572-
API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), values)
596+
API.mlirDenseElementsAttrFloat16Get(
597+
shaped_type, length(values), to_row_major(values)
598+
),
573599
)
574600
end
575601

576602
function DenseElementsAttribute(values::AbstractArray{<:Complex})
577603
shaped_type = TensorType(size(values), Type(eltype(values)))
578-
# TODO: row major
579604
return Attribute(
580605
API.mlirDenseElementsAttrRawBufferGet(
581-
shaped_type, length(values) * Base.elsize(values), values
606+
shaped_type, length(values) * Base.elsize(values), to_row_major(values)
582607
),
583608
)
584609
end
@@ -592,7 +617,9 @@ function DenseElementsAttribute(values::AbstractArray{String})
592617
# TODO may fail because `Type(String)` is not defined
593618
shaped_type = TensorType(size(values), Type(String))
594619
return Attribute(
595-
API.mlirDenseElementsAttrStringGet(shaped_type, length(values), values)
620+
API.mlirDenseElementsAttrStringGet(
621+
shaped_type, length(values), to_row_major(values)
622+
),
596623
)
597624
end
598625

@@ -663,25 +690,29 @@ function DenseArrayAttribute end
663690

664691
@llvmversioned min = v"16" DenseArrayAttribute(
665692
values::AbstractArray{Bool}; context::Context=context()
666-
) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), values))
693+
) = Attribute(
694+
API.mlirDenseBoolArrayGet(
695+
context, length(values), AbstractArray{Cint}(to_row_major(values))
696+
),
697+
)
667698
@llvmversioned min = v"16" DenseArrayAttribute(
668699
values::AbstractArray{Int8}; context::Context=context()
669-
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), values))
700+
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values)))
670701
@llvmversioned min = v"16" DenseArrayAttribute(
671702
values::AbstractArray{Int16}; context::Context=context()
672-
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), values))
703+
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values)))
673704
@llvmversioned min = v"16" DenseArrayAttribute(
674705
values::AbstractArray{Int32}; context::Context=context()
675-
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), values))
706+
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values)))
676707
@llvmversioned min = v"16" DenseArrayAttribute(
677708
values::AbstractArray{Int64}; context::Context=context()
678-
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), values))
709+
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values)))
679710
@llvmversioned min = v"16" DenseArrayAttribute(
680711
values::AbstractArray{Float32}; context::Context=context()
681-
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), values))
712+
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values)))
682713
@llvmversioned min = v"16" DenseArrayAttribute(
683714
values::AbstractArray{Float64}; context::Context=context()
684-
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), values))
715+
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values)))
685716

686717
@llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values)
687718

test/basic.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,18 @@ end
625625
@test y == x
626626
end
627627
end
628+
629+
function f_row_major(x)
630+
y = [1 2; 3 4; 5 6]
631+
if x isa Reactant.TracedRArray
632+
y = Reactant.promote_to(Reactant.TracedRArray{eltype(x),2}, y)
633+
end
634+
return x .+ y
635+
end
636+
637+
@testset "array attributes: row major" begin
638+
x = zeros(Int, 3, 2)
639+
x_ra = Reactant.to_rarray(x)
640+
641+
@test @jit(f_row_major(x_ra)) f_row_major(x)
642+
end

0 commit comments

Comments
 (0)