@@ -83,6 +83,11 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const
83
83
_2SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
84
84
workspacePtr, workspaceBytes, stream);
85
85
break ;
86
+ case ClusterShape::ClusterShape_1x4x1:
87
+ return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1>,
88
+ _1SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
89
+ workspacePtr, workspaceBytes, stream);
90
+ break ;
86
91
default :
87
92
throw std::runtime_error (" invalid config for fp8 gemm" );
88
93
break ;
@@ -205,9 +210,8 @@ std::vector<CutlassGemmConfig> CutlassFp8GemmRunner<T>::getConfigs() const {
205
210
};
206
211
207
212
std::vector<ClusterShape> clusterShapes = {
208
- ClusterShape::ClusterShape_1x1x1,
209
- ClusterShape::ClusterShape_1x2x1,
210
- ClusterShape::ClusterShape_2x1x1,
213
+ ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
214
+ ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
211
215
ClusterShape::ClusterShape_2x2x1,
212
216
};
213
217
for (auto const & tile_config : tilesSm100) {
0 commit comments