[Sharktank] Fix _TEST_LAST_OP_DISPATCH for wrapped functions #2564
base: main
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2564 +/- ##
=======================================
Coverage   ?   77.57%
=======================================
Files      ?   264
Lines      ?   25154
Branches   ?   0
=======================================
Hits       ?   19512
Misses     ?   5642
Partials   ?   0
sogartar
left a comment
sharktank/sharktank/ops/_registry.py
Outdated
selected_override, *results = trampoline(self, *args, **kwargs)
if _ENABLE_TEST_LAST_OP_DISPATCH:
    global _TEST_LAST_OP_DISPATCH
    _TEST_LAST_OP_DISPATCH = selected_override

    if hasattr(selected_override, "_trivially_replicable_wrapper"):
        # For trivially replicable wrappers, don't set _TEST_LAST_OP_DISPATCH;
        # the inner calls (which occurred already) will set it to the actual op.
        # NOTE: This assumes that all shards called the same op.
        pass
    else:
        # For wrappers such as `transfer_n_pin`, we set _TEST_LAST_OP_DISPATCH to the original op (not the wrapper).
        _TEST_LAST_OP_DISPATCH = getattr(
            selected_override, "_unwrapped", selected_override
        )
I guess more often than not you want the underlying op and not to test the trivial replication mechanism.
I am curious how you debug with this. Do you enable recording of the last dispatch somewhere in the model code and then inspect the recorded function to check that it is the correct one? Do you call it again with the same arguments?
If it is about tracing the dispatches, maybe we can log what is getting called here.
Actually, to log the outer calls first, we would need to do the logging in the trampoline, before the call.
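The ordering point can be illustrated with a small sketch. This is not sharktank's dispatcher: `dispatch`, `outer_impl`, `inner_impl`, and the logger name are hypothetical stand-ins, and the only idea shown is that logging before the call records an outer op ahead of any nested dispatch it triggers.

```python
import logging

logger = logging.getLogger("dispatch_sketch")  # hypothetical logger name


def dispatch(op_name, overrides, *args, **kwargs):
    """Try each override in order, logging before the call is made."""
    for override in overrides:
        # Logging before the call means the outer op is recorded before
        # any nested dispatch that the override itself performs.
        logger.debug("dispatching %s -> %s", op_name, override.__name__)
        result = override(*args, **kwargs)
        if result is not NotImplemented:
            return result
    raise NotImplementedError(f"no override of {op_name} accepted the arguments")


def inner_impl(x):
    return x + 1


def outer_impl(x):
    # The outer override dispatches a nested op.
    return dispatch("inner", [inner_impl], x) * 2
```

If the log call were placed after `override(...)` returned, the nested `inner` entry would appear before the `outer` entry.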
There are a few existing tests that use this; search for _test_enable_last_op_dispatch.
It seems to be used in unit tests when we have complicated overrides.
If we want to check in tests which override gets selected, maybe we should expose this.
E.g.
override = ops.my_op.get_override(*args)
The problem is that the dispatch mechanism is coupled to the execution of the selected override: the override itself may return NotImplemented, or it may continue executing to fulfil the request.
There may even be multiple ops that could match but will return NotImplemented based on arg combinations. It also won't work for trivially_replicable since we don't know what it calls on the shards until after it's done so.
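A toy sketch of that coupling, with made-up names (`override_fast`, `override_general`, `call_op`) and no relation to the real registry: both overrides match the argument types, so the selected one is known only after a candidate has run and declined to return NotImplemented.

```python
def override_fast(x, y):
    # Matches on argument types but declines some value combinations.
    if len(x) != len(y):
        return NotImplemented
    return [a + b for a, b in zip(x, y)]


def override_general(x, y):
    # Fallback that pads the shorter list with zeros.
    n = max(len(x), len(y))
    x = x + [0] * (n - len(x))
    y = y + [0] * (n - len(y))
    return [a + b for a, b in zip(x, y)]


def call_op(overrides, *args):
    """Return (selected_override, result); selecting requires running candidates."""
    for override in overrides:
        result = override(*args)
        if result is not NotImplemented:
            return override, result
    raise NotImplementedError("no override accepted the arguments")
```

Under these assumptions, a pure query such as the proposed `get_override` could report only the first type-level match, which is not necessarily the override that ends up doing the work.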
Any sort of nested dispatching would change what is happening. Here we only have some sort of patch for 2 special cases, transfer_n_pin and trivially_replicable.
Is your goal to make the unsharded tests work seamlessly in the replicated case, meaning that the same test code is used for both?
These tests rely on logic that is not supposed to be part of the API, which is usually bad practice.
Maybe we can record not just the last dispatch, but append to a list. Then the tests can check if the override is in the traced list.
That would have to make assumptions about how far back in the list to look. A simple wrapper like transfer_n_pin is not a problem since we know exactly which override it's wrapping. The issue is trivially_replicable. I think we should refactor it so that it wraps each individual override rather than the op as a whole.
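A rough sketch of what per-override wrapping might look like. The decorator name `replicate_override` and its list-of-shards calling convention are assumptions made for illustration; only the `_unwrapped` attribute name comes from the diff in this PR.

```python
import functools


def replicate_override(override):
    """Wrap one specific override so that it runs on every shard."""

    @functools.wraps(override)
    def wrapper(shards, *args, **kwargs):
        results = [override(shard, *args, **kwargs) for shard in shards]
        # If the override declines any shard, decline the whole call.
        if any(r is NotImplemented for r in results):
            return NotImplemented
        return results

    # The wrapped override is known up front, before any shard is processed.
    wrapper._unwrapped = override
    return wrapper
```

Because each wrapper is tied to exactly one override, the dispatcher could report `wrapper._unwrapped` without waiting to see what ran on the shards.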
Signed-off-by: Alex Vasile <[email protected]>
The transfer_n_pin wrapper is fairly straightforward: we simply have to plumb the underlying wrapped function through.
The trivially_replicable wrapper is messier. We don't know which version of the op it will dispatch on the shards until after it has done so. We don't update the last op dispatched and instead leave _TEST_LAST_OP_DISPATCH pointing at the op used by the last shard.
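The behaviour described above can be reduced to a small self-contained sketch. The attribute names `_unwrapped` and `_trivially_replicable_wrapper` are taken from the diff; `pinning_wrapper`, `record_last_dispatch`, and `last_dispatch` are hypothetical helpers standing in for the real wrapper and registry code.

```python
import functools

_TEST_LAST_OP_DISPATCH = None


def pinning_wrapper(fn):
    """Stand-in for a transfer_n_pin-style wrapper that exposes the wrapped op."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._unwrapped = fn  # plumb the underlying function through
    return wrapper


def record_last_dispatch(selected_override):
    global _TEST_LAST_OP_DISPATCH
    if hasattr(selected_override, "_trivially_replicable_wrapper"):
        # Leave the value set by the inner per-shard calls untouched.
        return
    _TEST_LAST_OP_DISPATCH = getattr(
        selected_override, "_unwrapped", selected_override
    )


def last_dispatch():
    return _TEST_LAST_OP_DISPATCH
```

With this shape, recording a wrapped op reports the original function, and recording a replication wrapper afterwards does not overwrite it.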