@@ -973,86 +973,86 @@ struct ConvertXeTileToXeGPUPass // convert XeTile to XeGPU
973
973
memSpace);
974
974
});
975
975
976
// Removed ('-') side of this hunk: registers the conversion from
// xetile::TileType to xegpu::TensorDescType.
//  * The tile's sg_map, when present, is rebuilt as an xegpu::SGMapAttr from
//    uint32_t copies of its wi_layout / wi_data arrays.
//  * A tile whose scatter attribute is true becomes a 1-D scattered
//    descriptor of getNumElements() elements with chunk size 1.
//  * Otherwise, a tile whose memory space resolves to Global becomes a block
//    descriptor with the tile's own shape and array length, boundary check
//    set to true.
//  * Any other memory space (SLM, per the inline comments) widens element
//    types narrower than 32 bits to i32/f32 and yields either a scattered
//    descriptor (col-major order) or a 1-D block descriptor (otherwise).
- typeConverter.addConversion ([&](xetile::TileType type)
977
- -> xegpu::TensorDescType {
978
- auto context = type.getContext ();
979
- auto scatterAttr = type.getScatterAttr ();
980
- bool isScattered = scatterAttr ? scatterAttr.getValue () : false ;
981
-
982
- // by default the targetTy is the element type, except for SLM cases,
983
- // where the data will be treated as 32-bit type implicitly.
984
- Type targetTy = type.getElementType ();
985
-
986
- xegpu::SGMapAttr sgMap = nullptr ;
987
- if (auto attr = type.getSgMap ()) {
988
- auto layout =
989
- llvm::to_vector_of< uint32_t >( attr.getWiLayout ().asArrayRef () );
990
- auto data = llvm::to_vector_of< uint32_t >(attr. getWiData (). asArrayRef () );
991
- sgMap = xegpu::SGMapAttr::get (context, layout, data);
992
- }
993
-
994
- auto memSpaceAttr = convertMemorySpace (type. getMemorySpace ());
995
- auto memSpace =
996
- memSpaceAttr ? memSpaceAttr. getValue () : xegpu::MemorySpace::Global;
997
-
998
- Attribute encoding ;
999
- llvm::SmallVector< int64_t > shape;
1000
- if (isScattered) {
1001
- // Scattered tile is lowered to scattered tensor_desc with chunk
1002
- // size 1. It supports both global memory and shared memory. while
1003
- // scattered tile can support 2D shape, scattered tensor_desc only
1004
- // support 1D shape.
1005
- auto chunkSizeAttr = IntegerAttr::get (IntegerType::get (context, 64 ), 1 );
1006
- auto msA = memSpaceAttr
1007
- ? memSpaceAttr
1008
- : xegpu::MemorySpaceAttr::get (context, memSpace);
1009
-
1010
- encoding =
1011
- xegpu::ScatterTensorDescAttr::get (context, msA, chunkSizeAttr);
1012
- shape.push_back (type.getNumElements ());
1013
- } else if (memSpace == xegpu::MemorySpace::Global) {
1014
- // Blocked tile on global memory is lowered to blocked tensor_desc
1015
- // with the same shape.
1016
- auto arrayLenAttr = type.getArrayLength ();
1017
- auto boundaryCheckAttr = BoolAttr::get (context, true );
1018
- encoding = xegpu::BlockTensorDescAttr::get (
1019
- context, memSpaceAttr, arrayLenAttr, boundaryCheckAttr);
1020
- shape = llvm::to_vector (type.getShape ());
1021
- } else {
1022
- // for TileType created for SLM access, it will be converted into:
1023
- // 1. a 1D block tensor_desc if it is for row-major access
1024
- // 2. a scattered tensor_desc if it is for col-major access.
1025
- auto elemBits = type.getElementType ().getIntOrFloatBitWidth ();
1026
- auto vnniFactor = std::max<int >(32 / elemBits, 1 );
1027
-
1028
- // SLM access only supports 32-bit or 64-bit data type, so convert
1029
- // the type if original element type is less than 32-bit.
1030
- if (elemBits < 32 ) {
1031
- targetTy = type.getElementType ().isInteger ()
1032
- ? (Type)IntegerType::get (context, 32 )
1033
- : (Type)Float32Type::get (context);
1034
- }
1035
-
1036
- if (isColMajorOrder (type.getOrder ())) {
1037
- // For access with col-major order
1038
- auto chunkSize = type.getShape ()[0 ] / vnniFactor;
1039
- auto chunkSizeAttr =
1040
- IntegerAttr::get (IntegerType::get (context, 64 ), chunkSize);
1041
- encoding = xegpu::ScatterTensorDescAttr::get (context, memSpaceAttr,
1042
- chunkSizeAttr);
1043
- shape = {type.getShape ()[1 ], chunkSize};
1044
- } else {
1045
- // For access with row-major order
1046
- auto vecSize = type.getNumElements () / vnniFactor;
1047
- encoding = xegpu::BlockTensorDescAttr::get (
1048
- context, memSpaceAttr, nullptr /* array_len*/ ,
1049
- nullptr /* boundary_check*/ );
1050
- shape.push_back (vecSize);
1051
- }
1052
- }
1053
- return xegpu::TensorDescType::get (context, shape, targetTy, encoding,
1054
- sgMap);
1055
- });
976
// Added ('+') side of this hunk: registers the conversion from
// xetile::TileType to xegpu::TensorDescType. The descriptor is chosen by
// three cases, tested in this order:
//   1. The tile's scatter attribute is present and true: a 1-D scattered
//      descriptor of getNumElements() elements with chunk size 1.
//   2. The memory space resolves to Global (also the fallback when
//      convertMemorySpace yields a null attribute): a block descriptor with
//      the tile's own shape and array length, boundary check set to true.
//   3. Any other memory space (SLM, per the inline comments): element types
//      narrower than 32 bits are widened to i32/f32; col-major order yields
//      a scattered descriptor, otherwise a 1-D block descriptor.
// The tile's sg_map, when present, is carried over as an xegpu::LayoutAttr.
+ typeConverter.addConversion (
977
+ [&](xetile::TileType type) -> xegpu::TensorDescType {
978
+ auto context = type.getContext ();
979
+ auto scatterAttr = type.getScatterAttr ();
980
+ bool isScattered = scatterAttr ? scatterAttr.getValue () : false ;
981
+
982
+ // by default the targetTy is the element type, except for SLM cases,
983
+ // where the data will be treated as 32-bit type implicitly.
984
+ Type targetTy = type.getElementType ();
985
+
986
// sgMap stays null unless the tile carries an sg_map attribute.
+ xegpu::LayoutAttr sgMap = nullptr ;
987
+ if (auto attr = type.getSgMap ()) {
988
+ auto layout = attr. getWiLayout (). asArrayRef ();
989
+ auto data = attr.getWiData ().asArrayRef ();
990
+ sgMap = xegpu::LayoutAttr::get (context, layout, data );
991
+ }
992
+
993
+ auto memSpaceAttr = convertMemorySpace (type. getMemorySpace ());
994
+ auto memSpace = memSpaceAttr ? memSpaceAttr. getValue ()
995
+ : xegpu::MemorySpace::Global;
996
+
997
+ Attribute encoding;
998
+ llvm::SmallVector< int64_t > shape ;
999
+ if (isScattered) {
1000
+ // Scattered tile is lowered to scattered tensor_desc with chunk
1001
+ // size 1. It supports both global memory and shared memory. while
1002
+ // scattered tile can support 2D shape, scattered tensor_desc only
1003
+ // support 1D shape.
1004
+ auto chunkSizeAttr =
1005
+ IntegerAttr::get (IntegerType::get (context, 64 ), 1 );
1006
// The scatter encoding always receives a non-null memory-space attribute:
// the converted one if it exists, else one built from the Global fallback.
+ auto msA = memSpaceAttr
1007
+ ? memSpaceAttr
1008
+ : xegpu::MemorySpaceAttr::get (context, memSpace);
1009
+
1010
+ encoding =
1011
+ xegpu::ScatterTensorDescAttr::get (context, msA, chunkSizeAttr);
1012
+ shape.push_back (type.getNumElements ());
1013
+ } else if (memSpace == xegpu::MemorySpace::Global) {
1014
+ // Blocked tile on global memory is lowered to blocked tensor_desc
1015
+ // with the same shape.
1016
+ auto arrayLenAttr = type.getArrayLength ();
1017
+ auto boundaryCheckAttr = BoolAttr::get (context, true );
1018
+ encoding = xegpu::BlockTensorDescAttr::get (
1019
+ context, memSpaceAttr, arrayLenAttr, boundaryCheckAttr);
1020
+ shape = llvm::to_vector (type.getShape ());
1021
+ } else {
1022
+ // for TileType created for SLM access, it will be converted into:
1023
+ // 1. a 1D block tensor_desc if it is for row-major access
1024
+ // 2. a scattered tensor_desc if it is for col-major access.
1025
+ auto elemBits = type.getElementType ().getIntOrFloatBitWidth ();
1026
// 32 / elemBits original elements fit in one 32-bit value; the max() clamps
// the factor to 1 for element types wider than 32 bits.
+ auto vnniFactor = std::max<int >(32 / elemBits, 1 );
1027
+
1028
+ // SLM access only supports 32-bit or 64-bit data type, so convert
1029
+ // the type if original element type is less than 32-bit.
1030
+ if (elemBits < 32 ) {
1031
+ targetTy = type.getElementType ().isInteger ()
1032
+ ? (Type)IntegerType::get (context, 32 )
1033
+ : (Type)Float32Type::get (context);
1034
+ }
1035
+
1036
+ if (isColMajorOrder (type.getOrder ())) {
1037
+ // For access with col-major order
1038
// NOTE(review): integer division truncates when shape[0] is not a multiple
// of vnniFactor; presumably guaranteed upstream - confirm.
+ auto chunkSize = type.getShape ()[0 ] / vnniFactor;
1039
+ auto chunkSizeAttr =
1040
+ IntegerAttr::get (IntegerType::get (context, 64 ), chunkSize);
1041
+ encoding = xegpu::ScatterTensorDescAttr::get (
1042
+ context, memSpaceAttr, chunkSizeAttr);
1043
// Dimensions are swapped relative to the tile: tile dim 1 leads, and the
// packed dim 0 (chunkSize) becomes the trailing dimension.
+ shape = {type.getShape ()[1 ], chunkSize};
1044
+ } else {
1045
+ // For access with row-major order
1046
+ auto vecSize = type.getNumElements () / vnniFactor;
1047
+ encoding = xegpu::BlockTensorDescAttr::get (
1048
+ context, memSpaceAttr, nullptr /* array_len*/ ,
1049
+ nullptr /* boundary_check*/ );
1050
+ shape.push_back (vecSize);
1051
+ }
1052
+ }
1053
+ return xegpu::TensorDescType::get (context, shape, targetTy, encoding,
1054
+ sgMap);
1055
+ });
1056
1056
1057
1057
auto materializeWithCast = [&](OpBuilder &builder, Type type,
1058
1058
ValueRange inputs, Location loc) -> Value {
0 commit comments