
Commit 239111d

Random assortment of batch rules and cleanups (#347)
Test Plan:
- wait for tests
1 parent 5304c81 commit 239111d

File tree

6 files changed: +33 −23 lines changed


functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 3 additions & 1 deletion

@@ -334,7 +334,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 
 #define LOGICAL_COMPARISON_POINTWISE(op) \
   VMAP_SUPPORT(#op, \
-    SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN(op)), &ATEN_FN(op)>));
+    SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN(op)), &ATEN_FN(op)>)); \
+  m.impl(#op "_", inplacePlumbing2< \
+    DECLTYPE_AUTO(&binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor:: op ## _ >)>);
 
 LOGICAL_COMPARISON_POINTWISE(logical_and);
 LOGICAL_COMPARISON_POINTWISE(logical_or);
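The macro change above additionally registers batch rules for the in-place variants (logical_and_, logical_or_, ...). A minimal sketch of what this should enable from Python, assuming the functorch vmap API; shapes are illustrative:

import torch
from functorch import vmap

x = torch.rand(3, 5) > 0.5
y = torch.rand(3, 5) > 0.5

# Each row of x is logical_and-ed in place with the matching row of y.
vmap(torch.Tensor.logical_and_)(x, y)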

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 8 additions & 0 deletions

@@ -227,6 +227,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(conv_transpose1d);
   OP_DECOMPOSE2(conv_transpose2d, input);
   OP_DECOMPOSE2(conv_transpose3d, input);
+  OP_DECOMPOSE(conv1d);
+  OP_DECOMPOSE(conv2d);
+  OP_DECOMPOSE(conv3d);
+  OP_DECOMPOSE2(conv1d, padding);
+  OP_DECOMPOSE2(conv2d, padding);
+  OP_DECOMPOSE2(conv3d, padding);
+  OP_DECOMPOSE(_convolution_mode);
+  OP_DECOMPOSE(type_as);
  DECOMPOSE_FUNCTIONAL(diag_embed);
  DECOMPOSE_FUNCTIONAL(block_diag);
}
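With conv1d/conv2d/conv3d now decomposing to at::convolution, which already has a batch rule, vmap over convolutions no longer needs dedicated per-op kernels. A rough usage sketch with made-up shapes:

import torch
import torch.nn.functional as F
from functorch import vmap

images = torch.randn(8, 3, 16, 16)  # a batch of inputs to map over
weight = torch.randn(4, 3, 3, 3)    # one shared filter bank

# Each (3, 16, 16) slice is unsqueezed to the 4-D input conv2d expects.
out = vmap(lambda img: F.conv2d(img.unsqueeze(0), weight))(images)
print(out.shape)  # torch.Size([8, 1, 4, 14, 14])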

functorch/csrc/BatchRulesModules.cpp

Lines changed: 15 additions & 8 deletions

@@ -87,10 +87,20 @@ convolution_batch_rule(const Tensor& lhs, optional<int64_t> lhs_bdim, const Tens
     return result;
   }
 }
-Tensor convNd_decomp(const Tensor &self, const Tensor &weight, const optional<Tensor>& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
-  std::vector<int64_t> t(self.dim() - 2, 0);
-  IntArrayRef out_padding(t);
-  return at::convolution(self, weight, bias, stride, padding, dilation, false, out_padding, groups);
+
+Tensor _convolution_decomp(
+    const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
+    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
+    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
+    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
+  // Ignore everything. If the user called this in the normal way,
+  // then they should be fine.
+  (void*) benchmark;
+  (void*) deterministic;
+  (void*) cudnn_enabled;
+  (void*) allow_tf32;
+  return at::convolution(
+      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
 }
 
 // Tensor convNd_transpose_decomp(const Tensor &self, const Tensor &weight, const optional<Tensor>& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {

@@ -517,13 +527,10 @@ struct CudnnGridSampleBackwardBatchRuleHelper {
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("convolution", convolution_batch_rule);
-  // m.impl("conv_transpose2d", convNd_transpose_decomp);
+  m.impl("_convolution", _convolution_decomp);
   m.impl("mkldnn_convolution", mkldnn_convolution_decomp);
   m.impl("cudnn_convolution_backward", cudnn_convolution_backward_plumbing);
   m.impl("cudnn_convolution", cudnn_convolution_plumbing);
-  m.impl("conv1d", convNd_decomp);
-  m.impl("conv2d", convNd_decomp);
-  m.impl("conv3d", convNd_decomp);
 
   EXISTING_BDIM(im2col);
   EXISTING_BDIM(im2col_backward);
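The new _convolution_decomp simply drops the backend hints and defers to the flag-free convolution overload. A rough Python rendering of the same idea (a sketch assuming torch.convolution is exposed in your build; this function is not part of the commit):

import torch

def _convolution_decomp_py(input, weight, bias, stride, padding, dilation,
                           transposed, output_padding, groups,
                           benchmark, deterministic, cudnn_enabled, allow_tf32):
    # benchmark/deterministic/cudnn_enabled/allow_tf32 only steer kernel
    # selection, not the math, so the decomposition can safely ignore them.
    del benchmark, deterministic, cudnn_enabled, allow_tf32
    return torch.convolution(input, weight, bias, stride, padding, dilation,
                             transposed, output_padding, groups)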

test/discover_coverage.py

Lines changed: 7 additions & 1 deletion

@@ -359,9 +359,13 @@ def print_coverage_info(th=100, nn=25):
     statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
     top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)
 
+    # testing problems
+    exemptions = {
+        'torch.nn.functional.dropout',  # randomness
+    }
+
     # Allowed exemptions
     vmap_exemptions = {
-        'torch.nn.functional.dropout',  # randomness
         'torch.randn_like',  # randomness
         'torch.allclose',  # number output
         'torch.unique',  # dynamic

@@ -374,6 +378,8 @@ def print_coverage_info(th=100, nn=25):
     remove_from_set(statuses['test_vmapvjp_has_batch_rule'], vmap_exemptions)
     remove_from_set(statuses['test_op_has_batch_rule'], vmap_exemptions)
     remove_from_set(statuses['test_vmapjvp'], vmap_exemptions)
+    for test in tests:
+        remove_from_set(statuses[test], exemptions)
 
     print(f"total ops in set: {th + nn}")
     print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")

test/test_ops.py

Lines changed: 0 additions & 11 deletions

@@ -475,7 +475,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('double', 'channels_last'),
     xfail('nn.functional.gaussian_nll_loss'),
     xfail('nn.functional.poisson_nll_loss'),
-    skip('nn.functional.conv1d', device_type='cuda'),
     xfail('fft.rfft2'),
     xfail('lu'),
     skip('qr'),  # Nondetermistic

@@ -694,10 +693,8 @@ def test_vmapjvpall(self, device, dtype, op):
     xfail('__getitem__'),
     xfail('cdist'),
     xfail('cholesky'),
-    xfail('clamp', 'scalar'),
     xfail('complex'),
     xfail('copysign'),
-    xfail('corrcoef'),
     xfail('cummax'),
     xfail('cummin'),
     xfail('cumprod'),

@@ -732,7 +729,6 @@ def test_vmapjvpall(self, device, dtype, op):
     xfail('linalg.slogdet'),
     xfail('linalg.solve'),
     xfail('linalg.tensorinv'),
-    xfail('linalg.vector_norm'),
     xfail('logdet'),
     xfail('lu'),
     xfail('lu_solve'),

@@ -741,16 +737,11 @@ def test_vmapjvpall(self, device, dtype, op):
     xfail('masked_scatter'),
     xfail('masked_select'),
     xfail('matrix_exp'),
-    xfail('max', 'reduction_no_dim'),
-    xfail('median'),
-    xfail('min', 'reduction_no_dim'),
-    xfail('nanmedian'),
     xfail('nanquantile'),
     xfail('nn.functional.conv_transpose2d'),
     xfail('nn.functional.gelu'),
     xfail('nn.functional.pad', 'circular'),
     xfail('norm', 'fro'),
-    xfail('norm', 'inf'),
     xfail('norm', 'nuc'),
     xfail('pinverse'),
     xfail('prod'),

@@ -785,8 +776,6 @@ def test_vmapjvpall(self, device, dtype, op):
     xfail('nn.functional.instance_norm'),
     xfail('nn.functional.poisson_nll_loss'),
     xfail('nn.functional.conv_transpose3d'),
-    xfail('_masked.norm'),
-    xfail('_masked.normalize'),
     xfail('nn.functional.bilinear'),
     xfail('nn.functional.prelu'),
     xfail('nn.functional.glu'),
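The xfail removals above mean ops like median are now expected to pass test_vmapjvpall, i.e. forward-mode AD composed under vmap. A small sketch of the pattern that test exercises, assuming the functorch vmap/jvp API:

import torch
from functorch import vmap, jvp

primals = torch.randn(4, 10)
tangents = torch.randn(4, 10)

# Per-sample forward-mode derivative of a no-dim reduction,
# the case previously marked xfail('median').
out, out_tangent = vmap(lambda p, t: jvp(torch.median, (p,), (t,)))(primals, tangents)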

test/test_vmap.py

Lines changed: 0 additions & 2 deletions

@@ -3221,7 +3221,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('nanmean'),
     xfail('vstack'),
     xfail('nn.functional.dropout'),
-    xfail('nn.functional.conv2d', ''),
     xfail('nn.functional.batch_norm'),
     xfail('resize_'),
     xfail('view_as_complex'),

@@ -3253,7 +3252,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('short', 'channels_last'),
     xfail('unique_consecutive'),
     xfail('unique'),
-    xfail('nn.functional.conv1d'),
     xfail('nn.functional.cosine_embedding_loss'),
     xfail('nn.functional.ctc_loss'),
     xfail('nn.functional.gaussian_nll_loss'),
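With conv1d and conv2d dropped from the exhaustive-test xfail list, a quick hedged consistency check of the kind test_vmap_exhaustive performs, comparing vmap'ed conv1d against a plain Python loop over the batch:

import torch
import torch.nn.functional as F
from functorch import vmap

xs = torch.randn(8, 3, 32)
weight = torch.randn(6, 3, 5)

# vmap'ed conv1d should match stacking per-sample results.
batched = vmap(lambda x: F.conv1d(x.unsqueeze(0), weight).squeeze(0))(xs)
looped = torch.stack([F.conv1d(x.unsqueeze(0), weight).squeeze(0) for x in xs])
assert torch.allclose(batched, looped)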
