Skip to content

Commit fa892b0

Browse files
authored
Updated the fma and fp8 format
1 parent 79e6b58 commit fa892b0

File tree

1 file changed

+80
-1
lines changed

1 file changed

+80
-1
lines changed

models/H100TC.m

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,83 @@
2323
% D: Result of the operation D = A * B + C computed under the
2424
% specified tensor core configuration.
2525

26-
D = B200TC(alpha, A, B, beta, C, informat, outformat);
26+
function D=H100TC(alpha, A, B, beta, C, informat, outformat)
27+
%
28+
% H100TC Compute GEMM with a model of a tensor core of the H100 GPU.
29+
%
30+
% This function evaluates the expression D = A * B + C using the
31+
% H200 TC numerical-feature-based model. The accumulation of block
32+
% products is performed using recursive summation.
33+
%
34+
% Inputs
35+
% A: Left matrix operand for the matrix multiplication A * B.
36+
% B: Right matrix operand for the matrix multiplication A * B.
37+
% C: Matrix added to the product A * B.
38+
% informat: a string specifying the format of A and B.
39+
% Supported input formats:
40+
% fp8-(e5m2,e4m3), fp16, binary16, half,
41+
% bf16, bfloat16, tensorfloat32, tf32.
42+
% outformat: a string specifying the numerical format for C and D.
43+
% Supported output formats:
44+
% fp32, single, binary32,
45+
% fp16, binary16, half.
46+
%
47+
% Output
48+
% D: Result of the operation D = A * B + C computed under the
49+
% specified tensor core configuration.
50+
51+
% Allowed formats
52+
allowedOutFormats = {'fp32', 'single', 'binary32',...
53+
'fp16', 'binary16', 'half'};
54+
allowedInFormats = {'fp8-e5m2','fp8-e4m3','e5m2','e4m3',...
55+
'fp16','binary16', 'half','bf16','bfloat16','tensorfloat32','tf32'};
56+
57+
if exist('informat', 'var')
58+
if (~ismember(lower(informat), allowedInFormats))
59+
error('The specified input format is not supported.');
60+
end
61+
informat=lower(informat);
62+
end
63+
if (exist('outformat', 'var'))
64+
if (~ismember(lower(outformat), allowedOutFormats))
65+
error('The specified output format is not supported.');
66+
end
67+
outformat=lower(outformat);
68+
end
69+
70+
% Default structures assuming fp16 in and fp32 output. See
71+
% Generic_TC_Model.m for the information.
72+
def_params.fma = 16; % Fused multiply-add (FMA) size
73+
def_params.neab = 2; % TC extra alignment bits
74+
def_params.frmode = 'rz'; % TC final rounding mode
75+
def_params.stkbitenabled = 0;
76+
def_params.inter_pattern=0;
77+
78+
% Set up the model according to the formats specified.
79+
if ismember(informat, {'fp16','half','binary16'})
80+
if exist('outformat', 'var')
81+
if ismember(outformat, {'fp16','binary16','half'})
82+
def_params.frmode='rne'; % TC final rounding mode
83+
end
84+
end
85+
elseif ismember(informat, {'tf32', 'tensorfloat32'})
86+
def_params.fma=8;
87+
elseif ismember(informat, {'fp8-e5m2','fp8-e4m3','e5m2','e4m3'})
88+
% FMA size is 16, but interleaved pattern is used to join two
89+
% 16-element vectors.
90+
def_params.fma = 32;
91+
def_params.inter_pattern=0;
92+
def_params.neab=-10;
93+
if exist('outformat', 'var')
94+
def_outopts.format = outformat;
95+
if ismember(outformat, {'fp16','binary16','half'})
96+
def_params.frmode='rne';
97+
end
98+
end
99+
end
100+
101+
102+
D = GEMM(alpha, A, B, beta, C, informat, outformat, def_params);
103+
104+
end
105+

0 commit comments

Comments
 (0)