
Commit e0e6001 (parent b1e7a1b)

Support conj bit, neg bit

5 files changed (+42, −22)

functorch/csrc/BatchedTensorImpl.cpp

Lines changed: 3 additions & 4 deletions
@@ -24,6 +24,8 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
   , value_(std::move(value))
   , bdims_(std::move(bdims))
 {
+  // TODO: I don't think this ctor gets used.
+  TORCH_INTERNAL_ASSERT(false);
   TORCH_INTERNAL_ASSERT(value_.defined());
   set_storage_access_should_throw();
   set_has_contiguity_policy(HasContiguityPolicy::CustomBehavior);
@@ -147,10 +149,7 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
 }
 
 Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
-  DispatchKeySet key_set;
-  if (tensor.is_cuda()) {
-    key_set = key_set.add(DispatchKey::CUDA);
-  }
+  DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
   auto* batched = maybeGetBatchedImpl(tensor);
   if (batched) {
     auto requested_level = bdims.back().level();
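
The rewrite above stops special-casing CUDA: getKeysToPropagateToWrapper (defined in the header below) intersects the wrapped tensor's own key set with an allow-list, so the backend key and the new conjugate/negative bit keys ride along uniformly. A quick sketch of the user-visible contract in Python (the second branch assumes a CUDA build):

import torch
from functorch import vmap

x = torch.randn(3, 4)
print(vmap(torch.sum)(x).device)        # cpu: the CPU key propagated

if torch.cuda.is_available():
    xc = x.cuda()
    print(vmap(torch.sum)(xc).device)   # cuda:0: the CUDA key propagated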

functorch/csrc/BatchedTensorImpl.h

Lines changed: 13 additions & 0 deletions
@@ -160,5 +160,18 @@ TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
 // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
 TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
 
+constexpr DispatchKeySet kKeysToPropagateToWrapper({
+  DispatchKey::Negative,
+  DispatchKey::Conjugate,
+  DispatchKey::XLA,
+  DispatchKey::CUDA,
+  DispatchKey::CPU,
+});
+
+inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
+  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
+  return key_set & to_propagate;
+}
+
 }
 }
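
DispatchKey::Negative and DispatchKey::Conjugate are the keys behind PyTorch's lazy negation and conjugation bits: conj() on a complex tensor flips a per-tensor flag instead of copying data, and the flag lives in the tensor's dispatch key set, so a wrapper that drops these keys silently loses the bit. A minimal sketch of the bits themselves, assuming a PyTorch build with lazy conj/neg (1.10+):

import torch

x = torch.tensor([1 + 1j, 2 - 1j])

y = x.conj()                        # lazy: sets the conjugate bit, no copy
print(y.is_conj())                  # True
print(y.resolve_conj().is_conj())   # False: resolve_conj materializes it

z = y.imag                          # imag of a conj view sets the negative bit
print(z.is_neg())                   # True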

functorch/csrc/TensorWrapper.cpp

Lines changed: 6 additions & 18 deletions
@@ -62,15 +62,9 @@ void dumpTensorCout(const Tensor& tensor) {
 }
 
 c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
-  // TODO: denylist non-cuda/cpu backends to avoid funny business
-  DispatchKeySet key_set;
-  if (tensor.is_cuda()) {
-    key_set = key_set.add(DispatchKey::CUDA);
-    key_set = key_set.add(DispatchKey::AutogradCUDA);
-  } else {
-    key_set = key_set.add(DispatchKey::CPU);
-    key_set = key_set.add(DispatchKey::AutogradCPU);
-  }
+  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
+    DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
+  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(kGradWrapperKey);
   if (should_be_alive) {
     return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
@@ -85,15 +79,9 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
     TORCH_INTERNAL_ASSERT(wrapped->level() < level);
   }
 
-  // TODO: denylist non-cuda/cpu backends to avoid funny business
-  DispatchKeySet key_set;
-  if (tensor.is_cuda()) {
-    key_set = key_set.add(DispatchKey::CUDA);
-    key_set = key_set.add(DispatchKey::AutogradCUDA);
-  } else {
-    key_set = key_set.add(DispatchKey::CPU);
-    key_set = key_set.add(DispatchKey::AutogradCPU);
-  }
+  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
+    DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
+  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(kGradWrapperKey);
   auto life_handle = getLifeHandleForLevel(level);
   auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
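
The new tests below exercise only the conjugate bit; here is a hypothetical sketch of the negative-bit path that the commit title also claims, assuming .imag of a lazily conjugated tensor keeps its negative bit visible through the TensorWrapper:

import torch
from functorch import grad

x = torch.tensor(2. + 1j)

def f(x):
    v = x.conj().imag    # conj is lazy, so .imag sets the negative bit
    assert v.is_neg()    # should hold now that Negative propagates to wrappers
    return v

grad(f)(x)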

test/test_eager_transforms.py

Lines changed: 10 additions & 0 deletions
@@ -200,6 +200,16 @@ def f(x):
         result, vjp_fn = vjp(f, torch.tensor(1.))
         vjp_fn(result)
 
+    def test_conj_bit(self):
+        x = torch.tensor(1+1j)
+        def foo(x):
+            assert not x.is_conj()
+            y = x.conj()
+            assert y.is_conj()
+            return y
+        res = grad(foo)(x)
+        self.assertEqual(res, torch.ones_like(res))
+
     def test_composed_with_autograd(self, device):
         x = torch.randn([], requires_grad=True, device=device)

test/test_vmap.py

Lines changed: 10 additions & 0 deletions
@@ -2618,6 +2618,16 @@ def test_one_hot(self):
         for loop_out, batched_out in get_fallback_and_vmap_exhaustive(F.one_hot, args, {}):
             self.assertEqual(loop_out, batched_out)
 
+    def test_conj_bit(self):
+        x = torch.tensor([1+1j, 2+1j])
+        def foo(x):
+            assert not x.is_conj()
+            y = x.conj()
+            assert y.is_conj()
+            return y
+        res = vmap(foo)(x)
+        self.assertEqual(res, x.conj())
+
     def test_mode_key(self):
         def vmap_f(x):
             return x + torch.randn(())
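
An analogous hypothetical check of the negative bit under vmap, mirroring the test above (assuming the Negative key survives BatchedTensorImpl wrapping the same way Conjugate does):

import torch
from functorch import vmap

x = torch.tensor([1 + 1j, 2 + 1j])

def foo(x):
    v = x.conj().imag    # sets the negative bit lazily
    assert v.is_neg()
    return v

res = vmap(foo)(x)
print(torch.equal(res, -x.imag))   # True: matches an eager negation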
