[Hipblaslt][CMS] CMS Support for TF32 128x192x32 NN #3677
base: hipblaslt_common_cms_phase2
ceca724 to b3a28f2
sebvince
left a comment
The schedule looks good, but I think you could easily gain perf by adding intermediate waits before the CVTs.
This might also help reduce the LDS stalls on LRA if you have the budget to spread them out a bit.
Also, using the 4x4x4_16b mfma for the CVTs would clearly help reduce the overall time (maybe in a separate PR ;) )
- Range: [[128], [192], [1], [64, 64, 256]]
- Range: [[128], [192], [1], [1,1,64]]
- Range: [[128], [192], [1], [32, 64, 256]]
- Range: [[4096], [6144], [1], [64, 64, 256]]
I think we just want to keep small shapes in this test file (removing these 2 lines)
syncTable = [
-1, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Begininng of a iteration. Wait for prior local read.") ,
10, SWaitCnt(dscnt=1, vlcnt=-1, vscnt=-1, comment="Before PackA0. Wait for first all LRA0. Skip 1*LRB0") ,
17, SWaitCnt(dscnt=6, vlcnt=-1, vscnt=-1, comment="Before GRA. Wait for all prior LRA0 for GRA. Skip 6*LRB0") ,
Should be SWaitCnt(dscnt=5) according to the trace, as the 6th LRB0 is issued at the same mfma index but after this wait. Wasn't it caught by the validator?
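To make the counting argument concrete, here is a minimal sketch of how the largest safe `dscnt` could be derived from the issue order. It assumes LDS reads return in issue order and that a wait with `dscnt=N` blocks until at most N reads remain outstanding. The helper `max_safe_dscnt`, the read names, and the count of four LRA0 reads are hypothetical and are not taken from the PR or from the validator.

```python
def max_safe_dscnt(reads_before_wait, needed):
    """Largest dscnt that still guarantees every read in `needed` has returned.

    reads_before_wait: names of LDS reads issued before the wait, in issue order.
    needed: set of read names that the instructions after the wait depend on.
    Assumes in-order completion of LDS reads (an assumption of this sketch).
    """
    positions = [i for i, name in enumerate(reads_before_wait) if name in needed]
    if not positions:
        raise ValueError("none of the needed reads were issued before the wait")
    last_needed = max(positions)
    # Only reads issued after the last needed one may stay in flight.
    return len(reads_before_wait) - 1 - last_needed


# Hypothetical schedule: 4 LRA0 reads followed by LRB0 reads.
lra0 = [f"LRA0_{i}" for i in range(4)]
lrb0 = [f"LRB0_{i}" for i in range(6)]

# If all six LRB0 reads were issued before the wait, six may be skipped.
all_six = max_safe_dscnt(lra0 + lrb0, set(lra0))

# If the sixth LRB0 is issued after the wait, only five may be skipped.
only_five = max_safe_dscnt(lra0 + lrb0[:5], set(lra0))
```

With only five LRB0 reads issued ahead of the wait, a `dscnt=6` would allow the last LRA0 read to still be in flight, which is the concern raised above.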
kernel["UsePLRPack"] = True
syncTable = [
-1, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Begininng of a iteration. Wait for prior local read.") ,
10, SWaitCnt(dscnt=1, vlcnt=-1, vscnt=-1, comment="Before PackA0. Wait for first all LRA0. Skip 1*LRB0") ,
This wait is a bit conservative. The trace shows a 100 clk wait on it. You could add intermediate waits, as you don't need all LRA0 instructions for the first CVTs. I think the first 8 CVTs only need 6 ds_reads to complete, for example.
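A rough sketch of what splitting one wait into staged waits could look like, assuming in-order LDS completion. The event format, the helper `staged_waits`, and the total of 12 LRA0 reads are hypothetical; only the "6 ds_reads for the first 8 CVTs" figure and the single skipped LRB0 come from this thread.

```python
def staged_waits(events):
    """Compute a dscnt for each consumer stage in a program-ordered event list.

    events: list of ("read", name) or ("use", [names]) tuples.
    Returns one dscnt per "use" event: the number of reads issued so far
    that come after the last read the stage depends on.
    """
    issued = []
    waits = []
    for kind, payload in events:
        if kind == "read":
            issued.append(payload)
        else:
            # list.index raises ValueError if a stage uses a read not yet issued.
            last_needed = max(issued.index(name) for name in payload)
            waits.append(len(issued) - 1 - last_needed)
    return waits


# Hypothetical: 12 LRA0 reads and 1 LRB0 read are issued before the first CVTs.
lra0 = [f"LRA0_{i}" for i in range(12)]
events = [("read", name) for name in lra0]
events.append(("read", "LRB0_0"))
events.append(("use", lra0[:6]))   # first 8 CVTs need the first 6 reads
events.append(("use", lra0))       # remaining CVTs need all LRA0 reads

waits = staged_waits(events)
```

Under these assumptions the first stage can start at `dscnt=7` instead of stalling until `dscnt=1`, and the second stage keeps the existing `dscnt=1`.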
17, SWaitCnt(dscnt=6, vlcnt=-1, vscnt=-1, comment="Before GRA. Wait for all prior LRA0 for GRA. Skip 6*LRB0") ,
17, SBarrier(comment="GRA") ,
20, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB0. Wait for all prior LRB0.") ,
29, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before GRB. Wait for all prior LRB0.") ,
nit: not needed given you already have a SWaitCnt(dscnt=0) at 20
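One way to spot this kind of redundancy mechanically is to track an upper bound on outstanding LDS reads through the schedule. The sketch below is only an illustration: the flat event list and the helper `redundant_waits` are hypothetical, and it assumes no ds_read is issued between the two waits at 20 and 29.

```python
def redundant_waits(events):
    """Return positions of waits that can never block.

    events: program-ordered list of ("read",) or ("wait", dscnt) tuples.
    Tracks an upper bound on outstanding LDS reads; a wait is redundant
    when that bound is already at or below its dscnt.
    """
    bound = float("inf")  # unknown number of reads in flight at loop entry
    redundant = []
    for pos, event in enumerate(events):
        if event[0] == "read":
            bound += 1
        else:
            dscnt = event[1]
            if bound <= dscnt:
                redundant.append(pos)
            else:
                bound = dscnt
    return redundant


# Hypothetical slice of a schedule: two reads, a full wait, then a second
# full wait with no reads in between (like the waits at 20 and 29).
events = [("wait", 0), ("read",), ("read",), ("wait", 0), ("wait", 0)]
flagged = redundant_waits(events)
```

Only the last wait is flagged, since nothing was issued after the previous `dscnt=0`.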
29, SBarrier(comment="GRB") ,
35, SWaitCnt(dscnt=-1, vlcnt=6, vscnt=-1, comment="Before LRB3. Wait for GRB from previous iter. Skip 4*GRA + 2*GRB") ,
35, SBarrier(comment="LRB") ,
44, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB3. Wait for all LRB3 for PackB3.") ,
Same remark. You could reduce this wait by adding intermediate SWaitCnts.
20, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB0. Wait for all prior LRB0.") ,
29, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before GRB. Wait for all prior LRB0.") ,
29, SBarrier(comment="GRB") ,
35, SWaitCnt(dscnt=-1, vlcnt=6, vscnt=-1, comment="Before LRB3. Wait for GRB from previous iter. Skip 4*GRA + 2*GRB") ,
This actually waits for both GRA and GRB of the previous iteration, right? So it looks like your second SWaitCnt(vlcnt=10) at index 53 is not necessary?
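The reasoning can be checked with a small model of in-order global loads. This is a sketch under assumptions: each iteration is taken to issue 4 GRA followed by 6 GRB loads (inferred from the "Skip 4*GRA + 6*GRB" comment), loads are assumed to return in issue order, and `guaranteed_done` is a hypothetical helper.

```python
def guaranteed_done(issued, vlcnt):
    """Loads guaranteed complete once at most `vlcnt` loads remain outstanding.

    issued: global loads in issue order; in-order return is assumed.
    """
    return issued[:max(len(issued) - vlcnt, 0)]


# Assumed per-iteration load pattern: 4 GRA then 6 GRB.
prev_iter = ["prev GRA"] * 4 + ["prev GRB"] * 6

# At index 35 the current iteration has issued 4 GRA and 2 GRB.
at_35 = prev_iter + ["cur GRA"] * 4 + ["cur GRB"] * 2
# At index 53 the remaining 4 GRB of the current iteration have been issued.
at_53 = at_35 + ["cur GRB"] * 4

done_35 = guaranteed_done(at_35, vlcnt=6)
done_53 = guaranteed_done(at_53, vlcnt=10)
```

In this model both waits guarantee exactly the same set of loads (all of the previous iteration), so the wait at 53 adds nothing once the wait at 35 has been passed.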
44, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackB3. Wait for all LRB3 for PackB3.") ,
53, SWaitCnt(dscnt=-1, vlcnt=10, vscnt=-1, comment="Before LRA3. Wait for GRA from previous iter. Skip 4*GRA + 6*GRB") ,
53, SBarrier(comment="LRA") ,
64, SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Before PackA3. Wait for all prior LRA3.") ,
Same here. Possible intermediate SWaitCnt :)
Motivation
Add CMS scheduling for TF32 192x256x32 NN.
Relevant PR: #3551
Technical Details
Test Plan
UseCustomMainLoopSchedule: [0, 1] to compare CMS on and off.
Test Result
Tensile
~30% uplift was observed when running Tensile
hipblaslt-bench
No gain versus the default kernel
Custom_Cijk_Ailk_Bljk_S_MX_B_BIAS_HA_S_SAV_NTD_SK3_UserArgs_MT256x256x32_MI16x16x1
Submission Checklist