@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
6969 Sharding sourceSharding,
7070 TypedValue<ShapedType> sourceShard, GridOp grid,
7171 int64_t splitTensorAxis, GridAxis splitGridAxis) {
72- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
72+ TypedValue<ShapedType> targetShard =
7373 AllSliceOp::create (builder, sourceShard, grid,
7474 ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
75- .getResult ()) ;
75+ .getResult ();
7676 Sharding targetSharding = targetShardingInSplitLastAxis (
7777 builder.getContext (), sourceSharding, splitTensorAxis, splitGridAxis);
7878 return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
204204 APInt (64 , splitTensorAxis));
205205 ShapedType targetShape =
206206 shardShapedType (sourceUnshardedShape, grid, targetSharding);
207- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
208- tensor::CastOp::create (builder, targetShape, allGatherResult)
209- .getResult ());
207+ TypedValue<ShapedType> targetShard =
208+ tensor::CastOp::create (builder, targetShape, allGatherResult).getResult ();
210209 return {targetShard, targetSharding};
211210}
212211
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
336335 APInt (64 , targetTensorAxis), APInt (64 , sourceTensorAxis));
337336 ShapedType targetShape =
338337 shardShapedType (sourceUnshardedShape, grid, targetSharding);
339- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
340- tensor::CastOp::create (builder, targetShape, allToAllResult).getResult ()) ;
338+ TypedValue<ShapedType> targetShard =
339+ tensor::CastOp::create (builder, targetShape, allToAllResult).getResult ();
341340 return {targetShard, targetSharding};
342341}
343342
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
510509 auto targetSharding = target.getSharding ();
511510 ImplicitLocOpBuilder implicitLocOpBuilder (target->getLoc (), builder);
512511 return reshard (implicitLocOpBuilder, grid, sourceSharding, targetSharding,
513- cast<TypedValue<ShapedType>>(source.getSrc ()),
514- sourceShardValue);
512+ source.getSrc (), sourceShardValue);
515513}
516514
517515TypedValue<ShapedType> reshard (OpBuilder &builder, ShardOp source,
0 commit comments