|
23 | 23 | % D: Result of the operation D = A * B + C computed under the |
24 | 24 | % specified tensor core configuration. |
25 | 25 |
|
26 | | -D = B200TC(alpha, A, B, beta, C, informat, outformat); |
| 26 | +function D=H100TC(alpha, A, B, beta, C, informat, outformat) |
| 27 | +% |
| 28 | +% H100TC Compute GEMM with a model of a tensor core of the H100 GPU. |
| 29 | +% |
| 30 | +% This function evaluates the expression D = alpha * A * B + beta * C |
| 31 | +% using the H100 TC numerical-feature-based model. The accumulation of |
| 32 | +% block products is performed using recursive summation. |
| 33 | +% |
| 34 | +% Inputs |
| 35 | +% alpha, beta: Scalar factors that scale the product A * B and C. |
| 36 | +% A, B: Left and right matrix operands of the product A * B. |
| 37 | +% C: Matrix added to the scaled product alpha * A * B. |
| 38 | +% informat: A string specifying the format of A and B. |
| 39 | +% Supported input formats: |
| 40 | +% fp8-e5m2, fp8-e4m3 (or e5m2, e4m3), fp16, binary16, |
| 41 | +% half, bf16, bfloat16, tensorfloat32, tf32. |
| 42 | +% outformat: A string specifying the numerical format for C and D. |
| 43 | +% Supported output formats: |
| 44 | +% fp32, single, binary32, |
| 45 | +% fp16, binary16, half. |
| 46 | +% |
| 47 | +% Output |
| 48 | +% D: Result of the operation D = alpha * A * B + beta * C computed |
| 49 | +% under the specified tensor core configuration. |
| 50 | + |
| 51 | +% Allowed formats |
| 52 | +allowedOutFormats = {'fp32', 'single', 'binary32',... |
| 53 | + 'fp16', 'binary16', 'half'}; |
| 54 | +allowedInFormats = {'fp8-e5m2','fp8-e4m3','e5m2','e4m3',... |
| 55 | + 'fp16','binary16', 'half','bf16','bfloat16','tensorfloat32','tf32'}; |
| 56 | + |
| 57 | +if exist('informat', 'var') |
| 58 | + if (~ismember(lower(informat), allowedInFormats)) |
| 59 | + error('The specified input format is not supported.'); |
| 60 | + end |
| 61 | + informat=lower(informat); |
| 62 | +end |
| 63 | +if (exist('outformat', 'var')) |
| 64 | + if (~ismember(lower(outformat), allowedOutFormats)) |
| 65 | + error('The specified output format is not supported.'); |
| 66 | + end |
| 67 | + outformat=lower(outformat); |
| 68 | +end |
| 69 | + |
| 70 | +% Default model parameters, assuming fp16 inputs and fp32 output. See |
| 71 | +% Generic_TC_Model.m for details on each field. |
| 72 | +def_params.fma = 16; % Fused multiply-add (FMA) size |
| 73 | +def_params.neab = 2; % TC extra alignment bits |
| 74 | +def_params.frmode = 'rz'; % TC final rounding mode |
| 75 | +def_params.stkbitenabled = 0; % Sticky bit flag (disabled) |
| 76 | +def_params.inter_pattern = 0; % Interleaved accumulation pattern (disabled) |
| 77 | + |
| 78 | +% Set up the model according to the formats specified. |
| 79 | +if ismember(informat, {'fp16','half','binary16'}) |
| 80 | + if exist('outformat', 'var') |
| 81 | + if ismember(outformat, {'fp16','binary16','half'}) |
| 82 | + def_params.frmode='rne'; % TC final rounding mode |
| 83 | + end |
| 84 | + end |
| 85 | +elseif ismember(informat, {'tf32', 'tensorfloat32'}) |
| 86 | + def_params.fma=8; |
| 87 | +elseif ismember(informat, {'fp8-e5m2','fp8-e4m3','e5m2','e4m3'}) |
| 88 | +    % The hardware FMA size is 16, but an interleaved pattern joins two |
| 89 | +    % 16-element vectors, which the model captures with an FMA size of 32. |
| 90 | +    def_params.fma = 32; |
| 91 | +    def_params.inter_pattern = 0; |
| 92 | +    def_params.neab = -10; % TC extra alignment bits |
| 93 | + if exist('outformat', 'var') |
| 94 | + def_outopts.format = outformat; |
| 95 | + if ismember(outformat, {'fp16','binary16','half'}) |
| 96 | + def_params.frmode='rne'; |
| 97 | + end |
| 98 | + end |
| 99 | +end |
| 100 | + |
| 101 | + |
| 102 | +D = GEMM(alpha, A, B, beta, C, informat, outformat, def_params); |
| 103 | + |
| 104 | +end |
| 105 | + |
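For context, a minimal usage sketch of H100TC as added above; it is not part of the commit. It assumes that GEMM and Generic_TC_Model.m from this repository are on the MATLAB path and that the input matrices are rounded to the requested input format inside GEMM; the matrix size, scaling factors, and error metric below are illustrative choices.

% Usage sketch (illustrative only): compare the H100 tensor core model
% against a binary64 reference for an fp16 x fp16 -> fp32 GEMM.
n = 256;
A = rand(n); B = rand(n); C = zeros(n);   % stored in binary64; treated as fp16 inputs
alpha = 1; beta = 1;
D = H100TC(alpha, A, B, beta, C, 'fp16', 'fp32');
Dref = alpha*(A*B) + beta*C;              % binary64 reference result
relerr = norm(D - Dref, 'fro') / norm(Dref, 'fro');
fprintf('Relative error of the modelled fp16/fp32 GEMM: %.2e\n', relerr);
% The fp8 path can be exercised in the same way, e.g.
% D8 = H100TC(alpha, A, B, beta, C, 'fp8-e4m3', 'fp16');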