@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
7979 let mnemonic = "shard";
8080
8181 let parameters = (ins
82- AttrParameter<"::mlir::SymbolRefAttr ", "cluster placed">:$cluster,
82+ AttrParameter<"::mlir::FlatSymbolRefAttr ", "cluster placed">:$cluster,
8383 ArrayRefParameter<"MeshAxesAttr">:$split_axes,
8484 OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
8585 OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
9191 The MeshSharding attribute could be used in the encoding of a
9292 `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
9393
94- 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
94+ 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
9595 cluster where the distributed tensor is placed. The symbol must resolve to a
9696 `mesh.cluster` operation.
9797
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
145145 }];
146146
147147 let builders = [
148- AttrBuilder<(ins "SymbolRefAttr ":$cluster,
148+ AttrBuilder<(ins "FlatSymbolRefAttr ":$cluster,
149149 "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
150150 "ArrayRef<MeshAxis>": $partial_axes,
151151 "mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
156156 return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
157157 partial_type);
158158 }]>,
159- AttrBuilder<(ins "SymbolRefAttr ":$cluster,
159+ AttrBuilder<(ins "FlatSymbolRefAttr ":$cluster,
160160 "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
161161 return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
162162 }]>
0 commit comments