Skip to content

Conversation

@sxu
Copy link
Contributor

@sxu sxu commented Apr 22, 2025

Differential Revision: D73444078

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10355

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 40181b2 with merge base 22ba09e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 22, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

sxu added a commit to sxu/executorch that referenced this pull request Apr 22, 2025
Summary: Pull Request resolved: pytorch#10355

Differential Revision: D73444078
@sxu sxu force-pushed the export-D73444078 branch from d266e75 to e8fa403 Compare April 22, 2025 17:05
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

sxu added a commit to sxu/executorch that referenced this pull request Apr 22, 2025
Summary: Pull Request resolved: pytorch#10355

Differential Revision: D73444078
@sxu sxu force-pushed the export-D73444078 branch from e8fa403 to 785b463 Compare April 22, 2025 17:16
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

sxu added a commit to sxu/executorch that referenced this pull request Apr 22, 2025
Summary: Pull Request resolved: pytorch#10355

Differential Revision: D73444078
@sxu sxu force-pushed the export-D73444078 branch from 785b463 to 2fa9f60 Compare April 22, 2025 17:35
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

@sxu sxu requested a review from kimishpatel April 22, 2025 17:58
def call_operator(self, op, args, kwargs, meta):
from executorch.extension.llm.custom_ops import custom_ops # noqa

if op != torch.ops.aten.scaled_dot_product_attention.default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this op getting decomposed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to run this pass before to_edge, and avoid the decomposed version for perf reasons.

kT = self._transpose(k, meta)
vT = self._transpose(v, meta)

if mask is not None and mask.node.meta["val"].dtype == torch.bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a todo here that custom sdpa once supports boolean mask, this wont be needed. tag me on the todo

(mask, 0.0, float("-inf")),
{},
meta,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also worth checking if the mask is > 2D than add appropriate squeeze ops while making sure first N - 2 dims are all 1

meta,
)

custom_sdpa = super().call_operator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to add option here that allows us to assume that the mask will be causal and thus we can just set mask =None and is_causal = True, can you do that and add corresponding test?

Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some nits and special handling for mask and is_causal requested

@kimishpatel
Copy link
Contributor

cc: @guangy10 @larryliu0820

@sxu sxu force-pushed the export-D73444078 branch from 2fa9f60 to 05eaf00 Compare April 24, 2025 17:12
sxu added a commit to sxu/executorch that referenced this pull request Apr 24, 2025
Summary: Pull Request resolved: pytorch#10355

Reviewed By: billmguo

Differential Revision: D73444078
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

@sxu sxu requested a review from kimishpatel April 24, 2025 17:20
@sxu
Copy link
Contributor Author

sxu commented Apr 28, 2025

@kimishpatel can you take another look?

Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes. Looks good. @guangy10 we should look into adopting this as well.

Summary: Pull Request resolved: pytorch#10355

Reviewed By: billmguo, kimishpatel

Differential Revision: D73444078
@sxu sxu force-pushed the export-D73444078 branch from 05eaf00 to 40181b2 Compare April 29, 2025 03:00
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73444078

@facebook-github-bot facebook-github-bot merged commit 7054b1f into pytorch:main Apr 29, 2025
84 of 86 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported topic: not user facing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants