[Hipblaslt][CMS] CMS Support for TF32 128x192x32 NN #3677
base: hipblaslt_common_cms_phase2
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2269,8 +2269,45 @@ def _get_schedule_128x192x32_TF32(kernel, useLDSTr, TLDS): | |
| syncCode = [] | ||
| nglshift = nllshift = 0 # vmcnt shift for ngl and nll | ||
| if isNN(kernel) and not useLDSTr and TLDS==1: | ||
| # TODO: Add NN schedule in upcoming PR | ||
| return False, None | ||
| kernel["UsePLRPack"] = True | ||
| syncTable = [ | ||
| -1, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Beginning of an iteration. Wait for prior local read.") , | |
| 10, SWaitCnt(dscnt=1, vlcnt=-1, vscnt=-1, comment="Before PackA0. Wait for first all LRA0. Skip 1*LRB0") , | ||
| 17, SWaitCnt(dscnt=6, vlcnt=-1, vscnt=-1, comment="Before GRA. Wait for all prior LRA0 for GRA. Skip 6*LRB0") , | ||
|
Contributor
This should be SWaitCnt(dscnt=5) according to the trace, as the 6th is done at the same MFMA index and after this wait. Wasn't it caught by the validator? |
||
| 17, SBarrier(comment="GRA") , | ||
| 20, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB0. Wait for all prior LRB0.") , | ||
| 29, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before GRB. Wait for all prior LRB0.") , | ||
|
Contributor
nit: not needed, given that you already have a SWaitCnt(dscnt=0) at 20 |
||
| 29, SBarrier(comment="GRB") , | ||
| 35, SWaitCnt(dscnt=-1, vlcnt=6, vscnt=-1, comment="Before LRB3. Wait for GRB from previous iter. Skip 4*GRA + 2*GRB") , | ||
|
Contributor
This actually waits for both GRA and GRB of the previous iteration, right? So it looks like your second SWaitCnt(vlcnt=10) at index 53 is not necessary? |
||
| 35, SBarrier(comment="LRB") , | ||
| 44, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB3. Wait for all LRB3 for PackB3.") , | ||
|
Contributor
Same remark. You could reduce this wait by adding intermediate SWaitCnt instructions. |
||
| 53, SWaitCnt(dscnt=-1, vlcnt=10, vscnt=-1, comment="Before LRA3. Wait for GRA from previous iter. Skip 4*GRA + 6*GRB") , | ||
| 53, SBarrier(comment="LRA") , | ||
| 64, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackA3. Wait for all prior LRA3.") , | ||
|
Contributor
Same here. Possible intermediate SWaitCnt :) |
||
| ] | ||
| optSchedule = { | ||
| 'SYNC' : [syncTable[::2]], | ||
| 'GRIncA': [[0, 0, 1, 1, 2, 2, 3, 3, 4]], | ||
| 'GRIncB': [[4, 5, 5, 6, 6, 7, 7, 8, 8]], | ||
| 'LRA0' : [[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]], | ||
| 'LRB0' : [[8, 10, 12, 14, 16, 17], | ||
| [9, 11, 13, 15, 16, 18]], | ||
| 'PackA0': [create_range(10, 6, 17, 1, 8)], | ||
| 'PackB0': [create_range(20, 12, 34, 1, 6)], | ||
| 'GRA' : [[19, 20, 21, 22, 23, 24, 25, 26]], | ||
| 'GRB' : [[29, 30, 31, 32, 52, 53, 54, 55, 56, 57, 58, 59]], | ||
| 'LRB3' : [[35, 36, 37, 38, 39, 40]], | ||
| 'LRA3' : [[53, 53, 54, 54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59, 60, 61]], | ||
| 'PackB3': [create_range(44, 9, 54, 1, 8)], | ||
| 'PackA3': [create_range(64, 6, 71, 1, 8)], | ||
| 'LRSA' : [[28]], | ||
| 'LRSB' : [[28]], | ||
| 'LWSA' : [[60]], | ||
| 'LWSB' : [[61]], | ||
| 'LCC' : [[71, 71]], | ||
| } | ||
| syncCode = syncTable[1::2] | ||
| nglshift = nllshift = 10 | ||
| elif isTN(kernel) and not useLDSTr and TLDS==1: | ||
| kernel["UsePLRPack"] = True | ||
| syncTable = [ | ||
|
|
||
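The `syncTable` in the hunk above is one flat list that alternates a schedule index with a sync instruction, and it is then split with `syncTable[::2]` and `syncTable[1::2]`. A minimal sketch of that slicing pattern follows, assuming plain strings as stand-ins for the `SWaitCnt` / `SBarrier` objects (the stand-in values are illustrative only and are not Tensile objects):

```python
# Flat list alternating (index, instruction). Strings stand in for the
# SWaitCnt / SBarrier objects used in the real schedule (assumption).
sync_table = [
    -1, "wait dscnt=0",
    10, "wait dscnt=1",
    17, "wait dscnt=6",
    17, "barrier GRA",
]

sync_indices = sync_table[::2]   # even positions: where each sync is placed
sync_code = sync_table[1::2]     # odd positions: the sync instructions

# Both halves must have the same length: one index per instruction.
assert len(sync_indices) == len(sync_code)
for idx, code in zip(sync_indices, sync_code):
    print(idx, code)
```

Keeping each index next to its instruction in the source makes a row easy to review, at the cost of silently misaligning both slices if a single entry is dropped, which is why the length check is worth having in a sketch like this.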
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -189,4 +189,32 @@ BenchmarkProblems: | |
| - Range: [[192], [256], [1], [64, 64, 256]] | ||
| - Range: [[192], [256], [1], [1,1,64]] | ||
| - Range: [[192], [256], [1], [32, 64, 256]] | ||
| - BiasTypeArgs: ['b'] | ||
| - BiasTypeArgs: ['b'] | ||
| - # BenchmarkProblemSizeGroup - Standard - All problem | ||
| InitialSolutionParameters: | ||
| BenchmarkCommonParameters: | ||
| - KernelLanguage: ["Assembly"] | ||
| ForkParameters: | ||
| - MatrixInstruction: | ||
| - [16, 16, 32, 1, 1, 4, 6, 2, 2] | ||
| - DepthU: [32] | ||
| - LocalReadVectorWidth: [4] | ||
| - ScheduleIterAlg: [3] | ||
| - DirectToLds: [1] | ||
| - PrefetchGlobalRead: [2] | ||
| - PrefetchLocalRead: [1] | ||
| - UseCustomMainLoopSchedule: [1] | ||
| - StreamK: [3] | ||
| - StaggerU: [0] | ||
| - ClusterLocalRead: [1] | ||
| - TransposeLDS: [1] | ||
| - LDSTrInst: [0] | ||
| BenchmarkJoinParameters: | ||
| BenchmarkFinalParameters: | ||
| - ProblemSizes: | ||
| - Range: [[128], [192], [1], [64, 64, 256]] | ||
| - Range: [[128], [192], [1], [1,1,64]] | ||
| - Range: [[128], [192], [1], [32, 64, 256]] | ||
| - Range: [[4096], [6144], [1], [64, 64, 256]] | ||
|
Contributor
I think we just want to keep small shapes in this test file (removing these 2 lines) |
||
| - Exact: [2048, 3072, 1, 8192] | ||
| - BiasTypeArgs: ['b'] | ||
This wait is a bit conservative. The trace shows a 100-clk wait on it. You could add intermediate waits, as you don't need all LRA0 instructions for the first CVTs. I think you would need 6 ds_reads to complete the first 8 CVTs, for example.
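The dscnt remarks in this review come down to counting how many later local reads are already in flight when a wait executes. A rough sketch of that count follows, assuming the counter retires reads in issue order and that a read scheduled at the same index as the wait is issued after it (as the comment on index 17 suggests). `reads_in_flight` is a hypothetical helper, not part of Tensile:

```python
def reads_in_flight(later_read_indices, wait_index):
    """Count later reads issued strictly before the wait at wait_index.

    Under the in-order assumption, this is the largest dscnt value that
    still guarantees every earlier read has completed.
    """
    return sum(1 for idx in later_read_indices if idx < wait_index)


# First LRB0 index list from the schedule in this PR.
lrb0 = [8, 10, 12, 14, 16, 17]

print(reads_in_flight(lrb0, 10))  # wait before PackA0
print(reads_in_flight(lrb0, 17))  # wait before GRA
```

Under these assumptions the count is 1 at index 10, matching the existing `dscnt=1`, and 5 at index 17, which is consistent with the `dscnt=5` suggestion in the review.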