Commit 8bb1152

stmcgovern authored and pytorchmergebot committed
[DTensor] Fix convolution ops with bias=None in torch.compile (pytorch#167258)
Fixes pytorch#167091

DTensor convolution operations crashed when bias=None was passed under torch.compile, because the code assumed bias always exists while the ATen schema defines it as optional (Tensor?). This fix:

- Handles None bias_spec in convolution_rules (forward pass)
- Handles None bias_shape_opt in convolution_backward_rules
- Returns None for grad_bias_spec when bias is None
- Extends None output handling to indices 0, 1, and 2 in _sharding_prop.py

Added 3 regression tests covering compile mode, the backward pass, and the nn.Conv2d module API with bias=False.

This is related to issue pytorch#159959 and PR pytorch#165438, which resolves it; the two overlap in the `_sharding_prop.py` change.

Pull Request resolved: pytorch#167258
Approved by: https://github.com/XilunWu
1 parent bbf39ca commit 8bb1152
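
For context, a minimal single-process sketch of the crash scenario described in the message above; the process-group and mesh setup here are illustrative assumptions (recent PyTorch with the public torch.distributed.tensor API), not code from the PR:

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Replicate

    # Single-process "distributed" setup so DTensor has a mesh to work with.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))

    x_dt = distribute_tensor(torch.randn(1, 4, 5, 5), mesh, [Replicate()])
    w_dt = distribute_tensor(torch.randn(8, 4, 3, 3), mesh, [Replicate()])

    def conv_fn(x, w):
        # bias=None is the case that used to trip the DTensor sharding rules
        return F.conv2d(x, w, bias=None, padding=1)

    out = torch.compile(conv_fn)(x_dt, w_dt)  # raised an AssertionError before this fix
    print(out.shape)  # torch.Size([1, 8, 5, 5])
    dist.destroy_process_group()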

File tree: 3 files changed (+137, -23 lines)

test/distributed/tensor/test_convolution_ops.py

Lines changed: 96 additions & 0 deletions
@@ -230,6 +230,98 @@ def test_conv3d(self):
         out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
         self.assertEqual(out_dt, out)
 
+    @with_comms
+    def test_conv2d_no_bias_compile(self):
+        """Test Conv2d with bias=False in compile mode (Issue #167091)
+
+        Regression test: Previously this would fail during torch.compile
+        tracing with AssertionError when bias_spec was None.
+        """
+        device_mesh = self.build_device_mesh()
+
+        def conv_fn(x, w):
+            return F.conv2d(x, w, bias=None, padding=1)
+
+        compiled_fn = torch.compile(conv_fn)
+
+        # Create tensors
+        x = torch.randn(1, 4, 5, 5, device=self.device_type)
+        w = torch.randn(8, 4, 3, 3, device=self.device_type)
+
+        # Distribute tensors
+        x_dt = distribute_tensor(x, device_mesh, [Replicate()])
+        w_dt = distribute_tensor(w, device_mesh, [Replicate()])
+
+        # Test eager mode for comparison
+        result_eager = conv_fn(x_dt, w_dt)
+
+        # Test compiled mode - this should not crash
+        result_compiled = compiled_fn(x_dt, w_dt)
+
+        # Verify shape is correct (the key regression test)
+        self.assertEqual(result_compiled.shape, torch.Size([1, 8, 5, 5]))
+
+        # Verify numerical correctness
+        torch.testing.assert_close(result_compiled.to_local(), result_eager.to_local())
+
+    @with_comms
+    def test_conv2d_no_bias_backward(self):
+        """Test Conv2d backward pass with bias=False (Issue #167091)
+
+        Regression test: Previously backward pass would fail when
+        grad_bias_spec was None.
+        """
+        device_mesh = self.build_device_mesh()
+
+        # Create tensors with requires_grad
+        x = torch.randn(1, 4, 5, 5, device=self.device_type)
+        w = torch.randn(8, 4, 3, 3, device=self.device_type, requires_grad=True)
+
+        # Distribute tensors
+        x_dt = distribute_tensor(x, device_mesh, [Replicate()])
+        w_dt = torch.nn.Parameter(distribute_tensor(w, device_mesh, [Replicate()]))
+
+        # Forward pass
+        result = F.conv2d(x_dt, w_dt, bias=None, padding=1)
+
+        # Backward pass - this should not crash
+        grad_output = torch.randn_like(result)
+        result.backward(grad_output)
+
+        # Check weight gradient exists (the key regression test)
+        self.assertIsNotNone(w_dt.grad)
+        self.assertEqual(w_dt.grad.shape, torch.Size([8, 4, 3, 3]))
+
+    @with_comms
+    def test_conv2d_module_no_bias(self):
+        """Test nn.Conv2d module with bias=False (Issue #167091)
+
+        Regression test: Ensures nn.Conv2d with bias=False works with DTensor.
+        """
+        device_mesh = self.build_device_mesh()
+
+        # Create model with bias=False
+        model = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False).to(
+            self.device_type
+        )
+        nn.init.ones_(model.weight)
+
+        # Distribute model
+        model_dt = distribute_module(model, device_mesh, _conv_fn)
+
+        # Create input
+        x = torch.randn(1, 4, 5, 5, device=self.device_type)
+        x_dt = distribute_tensor(x, device_mesh, [Replicate()])
+
+        # Forward pass - this should not crash
+        output_dt = model_dt(x_dt)
+
+        # Check output shape is correct
+        self.assertEqual(output_dt.shape, torch.Size([1, 8, 5, 5]))
+
+        # Check that model.bias is None
+        self.assertIsNone(model.bias)
+
 
 DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
     DistConvolutionOpsTest,
@@ -238,6 +330,10 @@ def test_conv3d(self):
         "test_conv_backward_none_grad_inp",
         "test_depthwise_convolution",
         "test_downsampling_convolution",
+        # New tests for Issue #167091 - use send/recv via tp_convolution
+        "test_conv2d_no_bias_compile",
+        "test_conv2d_no_bias_backward",
+        "test_conv2d_module_no_bias",
     ],
 )
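
Note that test_conv2d_module_no_bias passes a partition function _conv_fn defined earlier in this test file. A rough sketch of the replicate-everything pattern such a partition function follows (the name replicate_all_params and its body are illustrative, not the actual helper):

    import torch.nn as nn
    from torch.distributed.tensor import distribute_tensor, Replicate

    def replicate_all_params(name, module, device_mesh):
        # named_parameters(recurse=False) yields only real Parameters, so a
        # Conv2d built with bias=False simply has no bias entry to distribute.
        for param_name, param in list(module.named_parameters(recurse=False)):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
            module.register_parameter(param_name, dist_param)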

torch/distributed/tensor/_ops/_conv_ops.py

Lines changed: 33 additions & 17 deletions
@@ -26,15 +26,18 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
 
     assert isinstance(input_spec, DTensorSpec)
     assert isinstance(weight_spec, DTensorSpec)
-    assert isinstance(bias_spec, DTensorSpec)
+    # bias_spec can be None (optional parameter in aten.convolution schema)
+    if bias_spec is not None:
+        assert isinstance(bias_spec, DTensorSpec)
     assert input_spec.tensor_meta is not None
     assert weight_spec.tensor_meta is not None
     in_shape = input_spec.tensor_meta.shape
     weight_shape = weight_spec.tensor_meta.shape
-    assert isinstance(stride, list)
-    assert isinstance(padding, list)
-    assert isinstance(dilation, list)
-    assert isinstance(weight_shape, torch.Size)
+    assert isinstance(stride, list), f"stride must be list, got {type(stride)}"
+    assert isinstance(padding, list), f"padding must be list, got {type(padding)}"
+    assert isinstance(dilation, list), f"dilation must be list, got {type(dilation)}"
+    # weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing)
+    # so we don't assert its type, just use it
     out_conv_shape = [
         (d + 2 * padding[i] - dilation[i] * (weight_shape[i + 1] - 1) - 1) // stride[i]
         + 1
@@ -82,14 +85,21 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
     assert isinstance(grad_output_spec, DTensorSpec)
     assert isinstance(input_spec, DTensorSpec)
     assert isinstance(weight_spec, DTensorSpec)
-    assert isinstance(bias_shape_opt, list)
+    # bias_shape_opt can be None (optional parameter in aten.convolution_backward schema)
+    if bias_shape_opt is not None:
+        assert isinstance(bias_shape_opt, list)
     assert input_spec.tensor_meta is not None
     weight_tensor_meta = weight_spec.tensor_meta
-    bias_tensor_meta = TensorMeta(
-        torch.Size(bias_shape_opt),
-        (1,),
-        input_spec.tensor_meta.dtype,
-    )
+
+    # Only create bias_tensor_meta if bias_shape_opt is not None
+    if bias_shape_opt is not None:
+        bias_tensor_meta = TensorMeta(
+            torch.Size(bias_shape_opt),
+            (1,),
+            input_spec.tensor_meta.dtype,
+        )
+    else:
+        bias_tensor_meta = None
 
     grad_input_spec = input_spec
     grad_weight_spec = DTensorSpec.from_dim_map(
@@ -98,12 +108,18 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
         [0],
         tensor_meta=weight_tensor_meta,
     )
-    grad_bias_spec = DTensorSpec.from_dim_map(
-        input_spec.mesh,
-        [-1],
-        [0],
-        tensor_meta=bias_tensor_meta,
-    )
+
+    # Only create grad_bias_spec if we have bias_tensor_meta
+    if bias_tensor_meta is not None:
+        grad_bias_spec = DTensorSpec.from_dim_map(
+            input_spec.mesh,
+            [-1],
+            [0],
+            tensor_meta=bias_tensor_meta,
+        )
+    else:
+        grad_bias_spec = None
+
     # TODO: actually the output_mask is not respected here, we should
     # set the corresponding spec to `None` if the output_mask is not `False`
     # for a certain output Tensor. This also applies to the conv handler
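
The new guards rely on bias being optional in the ATen signature. One way to confirm the `Tensor?` annotation (using the internal _schema attribute that op overloads expose; output abbreviated in the comment below) is:

    import torch

    # The trailing "?" on the bias argument marks it as optional in the schema:
    # convolution(Tensor input, Tensor weight, Tensor? bias, ...) -> Tensor
    print(torch.ops.aten.convolution.default._schema)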

torch/distributed/tensor/_sharding_prop.py

Lines changed: 8 additions & 6 deletions
@@ -275,14 +275,16 @@ def _create_output_spec_with_new_tensor_meta(
         output_tensor_meta_i = output_tensor_meta[i]
         if not isinstance(output_tensor_meta_i, TensorMeta):
             # NOTE: aten.convolution_backward.default is an exception and it
-            # needs extra handling because the first Tensor in the output
-            # tuple can be `None` if the input Tensor to convolution op has
-            # `requires_grad=False` (e.g. convolution layer is the first
-            # layer in the model). We explicitly allow its corresponding
-            # TensorMeta to be `None`.
+            # needs extra handling because any Tensor in the output tuple
+            # can be `None` depending on the output_mask parameter. This can
+            # occur during double backpropagation or when certain gradients
+            # are not needed (e.g., grad_input when input has requires_grad=False,
+            # grad_weight/grad_bias when weight/bias have requires_grad=False,
+            # or grad_bias when bias is None). We explicitly allow the
+            # corresponding TensorMeta to be `None`.
             if (
                 op == aten.convolution_backward.default
-                and i == 0
+                and i in (0, 1, 2)
                 and output_tensor_meta_i is None
             ):
                 assert isinstance(output_specs, list)
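
To see why output indices 1 and 2 can also be None, one can invoke the op directly. A sketch with plain ATen tensors (no DTensor), following the aten.convolution_backward argument layout, where grad_bias is masked out:

    import torch

    x = torch.randn(1, 4, 5, 5)
    w = torch.randn(8, 4, 3, 3)
    out = torch.nn.functional.conv2d(x, w, bias=None, padding=1)
    grad_out = torch.randn_like(out)

    grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
        grad_out, x, w,
        None,                    # bias_sizes: None because there is no bias
        [1, 1], [1, 1], [1, 1],  # stride, padding, dilation
        False, [0, 0], 1,        # transposed, output_padding, groups
        [True, True, False],     # output_mask: grad_bias not requested
    )
    print(grad_bias)  # None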
