@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +28,13 @@ def swiglu(x, y=None):
     return F.silu(x) * y
 
 
+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
 except:
     pass
 
@@ -82,9 +89,16 @@ def padding_and_quant_input(tensor):
         return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale
 
     @staticmethod
-    def kitchen_gemm(
-        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled=True, is_b_1d_scaled=True, out=None, rtn_dtype=paddle.bfloat16
+    def kitchen_fp8_gemm(
+        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
     ):
+        if USE_DS_GEMM:
+            if out is None:
+                out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+            if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=112)
+            return out
+
         if out is not None:
             accumulate = True
             out_dtype = out.dtype
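For context on the new code path: USE_DS_GEMM is evaluated once at import time, so the environment variable must be set before this module is loaded. A minimal usage sketch follows, assuming the changed file is importable as fp8_utils (a hypothetical module name for illustration); any value other than "true" keeps the paddle.incubate.fp8 backend exactly as before.

import os

# Read once at import time via os.getenv("USE_DS_GEMM", "False").lower() == "true",
# so the variable must be set before the module is imported. The check is
# case-insensitive; anything other than "true" falls back to paddle.incubate.fp8.
os.environ["USE_DS_GEMM"] = "true"

import fp8_utils  # hypothetical name for the module changed in this diff

Within the DeepGEMM branch of kitchen_fp8_gemm, the numpy.prod guard skips the kernel launch when either input is zero-sized, so the zero-initialized (or caller-provided) out buffer is returned unchanged in that case.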