@@ -21,22 +21,6 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
2121 # (which is bad!)
2222 assert not data .requires_grad or not torch .is_grad_enabled ()
2323
24- # There is something very subtle going on here. In particular,
25- # suppose that elem is a view. Does all of the view metadata
26- # (sizes, strides, storages) get propagated correctly? Yes!
27- # Internally, the way _make_subclass works is it creates an
28- # alias (using Tensor.alias) of the original tensor, which
29- # means we replicate storage/strides, but with the Python object
30- # as an instance of your subclass. In other words,
31- # _make_subclass is the "easy" case of metadata propagation,
32- # because anything that alias() propagates, you will get in
33- # your subclass. It is _make_wrapper_subclass which is
34- # problematic...
35- #
36- # TODO: We need to think about how we want to turn this into
37- # official API. I am thinking that something that does the
38- # assert above and this call could be made into a utility function
39- # that is in the public API
4024 shape = [data .shape [i ] for i in v_to_p ]
4125 result = Tensor ._make_wrapper_subclass (cls , shape , dtype = data .dtype , device = data .device )
4226 result ._data = data # type: ignore
0 commit comments