@@ -492,6 +492,10 @@ function Base.fill(::Core.Type{Attribute}, value, shape)
492492 return Base. fill (value, shaped_type)
493493end
494494
495+ to_row_major (x) = permutedims (x, ndims (x): - 1 : 1 )
496+ to_row_major (x:: AbstractVector ) = x
497+ to_row_major (x:: AbstractArray{T,0} ) where {T} = x
498+
495499"""
496500 DenseElementsAttribute(array::AbstractArray)
497501
@@ -501,66 +505,86 @@ function DenseElementsAttribute(values::AbstractArray{Bool})
501505 shaped_type = TensorType (size (values), Type (Bool))
502506 return Attribute (
503507 API. mlirDenseElementsAttrBoolGet (
504- shaped_type, length (values), AbstractArray {Cint} (values)
508+ shaped_type, length (values), AbstractArray {Cint} (to_row_major ( values) )
505509 ),
506510 )
507511end
508512
509513function DenseElementsAttribute (values:: AbstractArray{UInt8} )
510514 shaped_type = TensorType (size (values), Type (UInt8))
511- return Attribute (API. mlirDenseElementsAttrUInt8Get (shaped_type, length (values), values))
515+ return Attribute (
516+ API. mlirDenseElementsAttrUInt8Get (shaped_type, length (values), to_row_major (values))
517+ )
512518end
513519
514520function DenseElementsAttribute (values:: AbstractArray{Int8} )
515521 shaped_type = TensorType (size (values), Type (Int8))
516- return Attribute (API. mlirDenseElementsAttrInt8Get (shaped_type, length (values), values))
522+ return Attribute (
523+ API. mlirDenseElementsAttrInt8Get (shaped_type, length (values), to_row_major (values))
524+ )
517525end
518526
519527function DenseElementsAttribute (values:: AbstractArray{UInt16} )
520528 shaped_type = TensorType (size (values), Type (UInt16))
521529 return Attribute (
522- API. mlirDenseElementsAttrUInt16Get (shaped_type, length (values), values)
530+ API. mlirDenseElementsAttrUInt16Get (
531+ shaped_type, length (values), to_row_major (values)
532+ ),
523533 )
524534end
525535
526536function DenseElementsAttribute (values:: AbstractArray{Int16} )
527537 shaped_type = TensorType (size (values), Type (Int16))
528- return Attribute (API. mlirDenseElementsAttrInt16Get (shaped_type, length (values), values))
538+ return Attribute (
539+ API. mlirDenseElementsAttrInt16Get (shaped_type, length (values), to_row_major (values))
540+ )
529541end
530542
531543function DenseElementsAttribute (values:: AbstractArray{UInt32} )
532544 shaped_type = TensorType (size (values), Type (UInt32))
533545 return Attribute (
534- API. mlirDenseElementsAttrUInt32Get (shaped_type, length (values), values)
546+ API. mlirDenseElementsAttrUInt32Get (
547+ shaped_type, length (values), to_row_major (values)
548+ ),
535549 )
536550end
537551
538552function DenseElementsAttribute (values:: AbstractArray{Int32} )
539553 shaped_type = TensorType (size (values), Type (Int32))
540- return Attribute (API. mlirDenseElementsAttrInt32Get (shaped_type, length (values), values))
554+ return Attribute (
555+ API. mlirDenseElementsAttrInt32Get (shaped_type, length (values), to_row_major (values))
556+ )
541557end
542558
543559function DenseElementsAttribute (values:: AbstractArray{UInt64} )
544560 shaped_type = TensorType (size (values), Type (UInt64))
545561 return Attribute (
546- API. mlirDenseElementsAttrUInt64Get (shaped_type, length (values), values)
562+ API. mlirDenseElementsAttrUInt64Get (
563+ shaped_type, length (values), to_row_major (values)
564+ ),
547565 )
548566end
549567
550568function DenseElementsAttribute (values:: AbstractArray{Int64} )
551569 shaped_type = TensorType (size (values), Type (Int64))
552- return Attribute (API. mlirDenseElementsAttrInt64Get (shaped_type, length (values), values))
570+ return Attribute (
571+ API. mlirDenseElementsAttrInt64Get (shaped_type, length (values), to_row_major (values))
572+ )
553573end
554574
555575function DenseElementsAttribute (values:: AbstractArray{Float32} )
556576 shaped_type = TensorType (size (values), Type (Float32))
557- return Attribute (API. mlirDenseElementsAttrFloatGet (shaped_type, length (values), values))
577+ return Attribute (
578+ API. mlirDenseElementsAttrFloatGet (shaped_type, length (values), to_row_major (values))
579+ )
558580end
559581
560582function DenseElementsAttribute (values:: AbstractArray{Float64} )
561583 shaped_type = TensorType (size (values), Type (Float64))
562584 return Attribute (
563- API. mlirDenseElementsAttrDoubleGet (shaped_type, length (values), values)
585+ API. mlirDenseElementsAttrDoubleGet (
586+ shaped_type, length (values), to_row_major (values)
587+ ),
564588 )
565589end
566590
@@ -569,16 +593,17 @@ end
569593function DenseElementsAttribute (values:: AbstractArray{Float16} )
570594 shaped_type = TensorType (size (values), Type (Float16))
571595 return Attribute (
572- API. mlirDenseElementsAttrFloat16Get (shaped_type, length (values), values)
596+ API. mlirDenseElementsAttrFloat16Get (
597+ shaped_type, length (values), to_row_major (values)
598+ ),
573599 )
574600end
575601
576602function DenseElementsAttribute (values:: AbstractArray{<:Complex} )
577603 shaped_type = TensorType (size (values), Type (eltype (values)))
578- # TODO : row major
579604 return Attribute (
580605 API. mlirDenseElementsAttrRawBufferGet (
581- shaped_type, length (values) * Base. elsize (values), values
606+ shaped_type, length (values) * Base. elsize (values), to_row_major ( values)
582607 ),
583608 )
584609end
@@ -592,7 +617,9 @@ function DenseElementsAttribute(values::AbstractArray{String})
592617 # TODO may fail because `Type(String)` is not defined
593618 shaped_type = TensorType (size (values), Type (String))
594619 return Attribute (
595- API. mlirDenseElementsAttrStringGet (shaped_type, length (values), values)
620+ API. mlirDenseElementsAttrStringGet (
621+ shaped_type, length (values), to_row_major (values)
622+ ),
596623 )
597624end
598625
@@ -663,25 +690,29 @@ function DenseArrayAttribute end
663690
664691@llvmversioned min = v " 16" DenseArrayAttribute (
665692 values:: AbstractArray{Bool} ; context:: Context = context ()
666- ) = Attribute (API. mlirDenseBoolArrayGet (context, length (values), values))
693+ ) = Attribute (
694+ API. mlirDenseBoolArrayGet (
695+ context, length (values), AbstractArray {Cint} (to_row_major (values))
696+ ),
697+ )
667698@llvmversioned min = v " 16" DenseArrayAttribute (
668699 values:: AbstractArray{Int8} ; context:: Context = context ()
669- ) = Attribute (API. mlirDenseI8ArrayGet (context, length (values), values))
700+ ) = Attribute (API. mlirDenseI8ArrayGet (context, length (values), to_row_major ( values) ))
670701@llvmversioned min = v " 16" DenseArrayAttribute (
671702 values:: AbstractArray{Int16} ; context:: Context = context ()
672- ) = Attribute (API. mlirDenseI16ArrayGet (context, length (values), values))
703+ ) = Attribute (API. mlirDenseI16ArrayGet (context, length (values), to_row_major ( values) ))
673704@llvmversioned min = v " 16" DenseArrayAttribute (
674705 values:: AbstractArray{Int32} ; context:: Context = context ()
675- ) = Attribute (API. mlirDenseI32ArrayGet (context, length (values), values))
706+ ) = Attribute (API. mlirDenseI32ArrayGet (context, length (values), to_row_major ( values) ))
676707@llvmversioned min = v " 16" DenseArrayAttribute (
677708 values:: AbstractArray{Int64} ; context:: Context = context ()
678- ) = Attribute (API. mlirDenseI64ArrayGet (context, length (values), values))
709+ ) = Attribute (API. mlirDenseI64ArrayGet (context, length (values), to_row_major ( values) ))
679710@llvmversioned min = v " 16" DenseArrayAttribute (
680711 values:: AbstractArray{Float32} ; context:: Context = context ()
681- ) = Attribute (API. mlirDenseF32ArrayGet (context, length (values), values))
712+ ) = Attribute (API. mlirDenseF32ArrayGet (context, length (values), to_row_major ( values) ))
682713@llvmversioned min = v " 16" DenseArrayAttribute (
683714 values:: AbstractArray{Float64} ; context:: Context = context ()
684- ) = Attribute (API. mlirDenseF64ArrayGet (context, length (values), values))
715+ ) = Attribute (API. mlirDenseF64ArrayGet (context, length (values), to_row_major ( values) ))
685716
686717@llvmversioned min = v " 16" Attribute (values:: AbstractArray ) = DenseArrayAttribute (values)
687718
0 commit comments