[autoWS][FA] Add vectorization and fadd2_reduce #540
Summary: mark N_CTX constexpr Test Plan: Reviewers: Subscribers: Tasks: Tags:
240f4d1 to ed86b25
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
else:
    m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
    qk = qk * qk_scale - m_ij[:, None]
if VECT_MUL == 2 or VECT_MUL == 3:
Any reason not to test the bits directly?
- if VECT_MUL == 2 or VECT_MUL == 3:
+ if VECT_MUL & 2:
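To make the suggestion concrete, here is a small host-side sketch in plain Python (not Triton) comparing the two conditions. It assumes `VECT_MUL` is a small bitmask in which bit 1 selects this path, which the thread implies but does not state; the function names are hypothetical.

```python
def takes_branch_eq(vect_mul: int) -> bool:
    """Condition as written in the diff."""
    return vect_mul == 2 or vect_mul == 3


def takes_branch_bit(vect_mul: int) -> bool:
    """Condition from the suggestion: test bit 1 directly."""
    return bool(vect_mul & 2)


# Values in 0..3 on which the two conditions agree.
agree = [v for v in range(4) if takes_branch_eq(v) == takes_branch_bit(v)]
# Values in 0..7 on which they disagree (only relevant if such values are ever passed).
differ = [v for v in range(8) if takes_branch_eq(v) != takes_branch_bit(v)]

print(agree)   # [0, 1, 2, 3]
print(differ)  # [6, 7]
```

The two forms are interchangeable as long as `VECT_MUL` stays in 0..3; if more bits were added later, the bit test would keep firing for values such as 6 or 7, which may or may not be the intent.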
p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
l_i0 = l_i0 * alpha + l_ij0
l_i1 = l_i1 * alpha + l_ij1
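For readers following the hunk above, this is a rough plain-Python model of what the reshape/permute/split followed by the paired reduction computes. It assumes `_reduce_fadd2` simply adds the two lanes pairwise, which its name suggests but the thread does not show; `split_halves` and `reduce_pairs` are hypothetical stand-ins, not Triton APIs.

```python
def split_halves(p, pm, pn):
    """Model of p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split().

    After the reshape, element (m, k, n) is p[m][k * (PN // 2) + n].
    The permute moves k to the last axis and split() separates k=0 from k=1,
    so p0 holds the left half of each row and p1 the right half.
    """
    half = pn // 2
    p0 = [[p[m][n] for n in range(half)] for m in range(pm)]
    p1 = [[p[m][half + n] for n in range(half)] for m in range(pm)]
    return p0, p1


def reduce_pairs(p0, p1):
    """Stand-in for tl.reduce((p0, p1), axis=1, ...): one row sum per half (assumed)."""
    return [sum(row) for row in p0], [sum(row) for row in p1]


p = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
p0, p1 = split_halves(p, pm=2, pn=4)
l_ij0, l_ij1 = reduce_pairs(p0, p1)

# The two partial sums add up to the ordinary row sum.
totals = [a + b for a, b in zip(l_ij0, l_ij1)]
print(l_ij0, l_ij1, totals)
```

Under that assumption, `l_ij0 + l_ij1` matches the usual `tl.sum(p, 1)` row sum, up to floating-point summation order.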
Can we sum them after doing the reduction so we can keep the same interface? (and keep the differences localised to this part of the program)
Would that also save a register?
You mean "l_ij = l_ij0 + l_ij1" then a single l_i = l_i * alpha + l_ij? That is a good point.
The trade-off is one addition saved inside the loop versus the register pressure of keeping l_i1.
I basically copied this from Gluon, and it is also what the dp version implements.
We can potentially clean this up if the alternative is better.
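As a sanity check on the two options discussed here, the sketch below compares them in plain Python: two running accumulators summed at the end, versus summing the partial results each iteration into a single accumulator. The function names, the zero initial value, and the sample `(alpha, l_ij0, l_ij1)` steps are all made up for illustration; this says nothing about which variant is faster or uses fewer registers on the GPU.

```python
def run_split(steps):
    """Variant in the diff: keep l_i0 and l_i1 separate, combine after the loop."""
    l_i0 = 0.0
    l_i1 = 0.0
    for alpha, l_ij0, l_ij1 in steps:
        l_i0 = l_i0 * alpha + l_ij0
        l_i1 = l_i1 * alpha + l_ij1
    return l_i0 + l_i1


def run_merged(steps):
    """Variant from the review: l_ij = l_ij0 + l_ij1, then a single l_i update."""
    l_i = 0.0
    for alpha, l_ij0, l_ij1 in steps:
        l_ij = l_ij0 + l_ij1
        l_i = l_i * alpha + l_ij
    return l_i


# Hypothetical per-iteration values, chosen to be exactly representable in binary.
steps = [(1.0, 3.0, 7.0), (0.5, 11.0, 15.0), (0.25, 2.0, 4.0)]
print(run_split(steps), run_merged(steps))
```

Because the update is linear in `l_i` and both accumulators share the same `alpha`, the two variants agree mathematically; with arbitrary floating-point inputs they can differ by rounding.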