diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
index 4ff79ad854..d0a6bfeb18 100644
--- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
+++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
@@ -77,16 +77,17 @@ void DisPatchW4AFp8Gemm(
         max_tokens,
         stream)
   } else {
-    GEMM_SWITCH_FP16(
-        M, K, batch_size, token_padding_size, kBlockN, TailN,
-        weight,
-        input,
-        out,
-        weight_scale,
-        input_row_sum,
-        tokens,
-        max_tokens,
-        stream)
+    PD_THROW("Only supported dtype in ['BF16'].");
+    // GEMM_SWITCH_FP16(
+    //     M, K, batch_size, token_padding_size, kBlockN, TailN,
+    //     weight,
+    //     input,
+    //     out,
+    //     weight_scale,
+    //     input_row_sum,
+    //     tokens,
+    //     max_tokens,
+    //     stream)
   }
 }
 
diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
index 87b06fa747..1798263a27 100644
--- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
+++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
@@ -85,12 +85,12 @@ gemm_case = [
     [8192, 3584, 8, 0],  # eb45T ffn1
-    [8192, 3584, 8, 2048],  # eb45T ffn1
-    [7168, 8192, 8, 0],  # eb45T ffn2
-    [7168, 8192, 8, 2048],  # eb45T ffn2
+    # [8192, 3584, 8, 2048],  # eb45T ffn1
+    # [7168, 8192, 8, 0],  # eb45T ffn2
+    # [7168, 8192, 8, 2048],  # eb45T ffn2
 ]
 
-dtype = ["BF16", "FP16"]
+dtype = ["BF16"]
 
 
 def get_cutlass_type(type):
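
Note on the patch as a whole: the two hunks disable the FP16 path end to end. The dispatcher in w4afp8_gemm.cu now throws on non-BF16 outputs instead of expanding GEMM_SWITCH_FP16, and the generator script stops emitting FP16 instantiations (three of the eb45T GEMM shapes are commented out as well). Below is a minimal sketch of why trimming the dtype list shrinks the build; the cross-product loop and the kernel-count print are illustrative assumptions, not the actual codegen logic of auto_gen_w4afp8_gemm_kernel.py.

```python
# A minimal sketch, assuming each (gemm_case, dtype) pair maps to one generated
# kernel instantiation; `gemm_case` and `dtype` mirror the names in
# auto_gen_w4afp8_gemm_kernel.py, everything else here is hypothetical.
gemm_case = [
    [8192, 3584, 8, 0],  # eb45T ffn1 (the only shape left enabled)
]
dtype = ["BF16"]  # was ["BF16", "FP16"] before this patch

# One instantiation per (shape, dtype) combination: dropping "FP16" halves
# the number of kernels the build must compile for each remaining shape.
kernels = [
    (m, k, batch, pad, t)
    for (m, k, batch, pad) in gemm_case
    for t in dtype
]
print(f"{len(kernels)} kernel instantiation(s) generated")
```

A caller that previously requested FP16 output will now hit the PD_THROW above, so the unsupported-dtype failure moves from a missing-kernel error at build or link time to an explicit run-time error.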