Skip to content

Commit 112a017

Browse files
authored
Convolution default fixed, extensive tests (#321)
The default (gather/scatter) backend for convolution is fixed, with aggressive testing to confirm. --------- Signed-off-by: Christopher Horvath <[email protected]>
1 parent cba58d0 commit 112a017

14 files changed

+1863
-181
lines changed

fvdb/convolution_plan.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,16 @@ def from_grid_batch(
189189
kernel_size = to_Vec3i(kernel_size, value_constraint=ValueConstraint.POSITIVE)
190190
stride = to_Vec3i(stride, value_constraint=ValueConstraint.POSITIVE)
191191

192-
if not _vec_is_all(kernel_size, kernel_size[0].item()):
193-
raise NotImplementedError("Non-uniform kernel sizes are not currently supported")
192+
if not _vec_is_all(stride, 1):
193+
raise NotImplementedError("Strides not equal to 1 are not currently supported")
194+
194195
if not _vec_is_all(stride, stride[0].item()):
195196
raise NotImplementedError("Non-uniform strides are not currently supported")
196197

197198
backend = expert_config.get("backend", "default")
199+
if backend != "default":
200+
raise NotImplementedError("Non-default backends are not currently supported")
201+
198202
if backend in ["dense", "halo", "lggs"]:
199203
if target_grid is not None:
200204
raise ValueError("Target grid must be None for dense, halo, and lggs backends.")
@@ -267,12 +271,16 @@ def from_grid_batch_transposed(
267271
kernel_size = to_Vec3i(kernel_size, value_constraint=ValueConstraint.POSITIVE)
268272
stride = to_Vec3i(stride, value_constraint=ValueConstraint.POSITIVE)
269273

270-
if not _vec_is_all(kernel_size, kernel_size[0].item()):
271-
raise NotImplementedError("Non-uniform kernel sizes are not currently supported")
274+
if not _vec_is_all(stride, 1):
275+
raise NotImplementedError("Strides not equal to 1 are not currently supported")
276+
272277
if not _vec_is_all(stride, stride[0].item()):
273278
raise NotImplementedError("Non-uniform strides are not currently supported")
274279

275280
backend = expert_config.get("backend", "default")
281+
if backend != "default":
282+
raise NotImplementedError("Non-default backends are not currently supported")
283+
276284
if backend == "dense":
277285
if target_grid is not None:
278286
raise ValueError("Target grid must be None for dense backend, transposed.")
@@ -348,12 +356,16 @@ def from_grid(
348356
kernel_size = to_Vec3i(kernel_size, value_constraint=ValueConstraint.POSITIVE)
349357
stride = to_Vec3i(stride, value_constraint=ValueConstraint.POSITIVE)
350358

351-
if not _vec_is_all(kernel_size, kernel_size[0].item()):
352-
raise NotImplementedError("Non-uniform kernel sizes are not currently supported")
359+
if not _vec_is_all(stride, 1):
360+
raise NotImplementedError("Strides not equal to 1 are not currently supported")
361+
353362
if not _vec_is_all(stride, stride[0].item()):
354363
raise NotImplementedError("Non-uniform strides are not currently supported")
355364

356365
backend = expert_config.get("backend", "default")
366+
if backend != "default":
367+
raise NotImplementedError("Non-default backends are not currently supported")
368+
357369
if backend in ["dense", "halo", "lggs"]:
358370
if target_grid is not None:
359371
raise ValueError("Target grid must be None for dense, halo, and lggs backends.")
@@ -424,12 +436,16 @@ def from_grid_transposed(
424436
kernel_size = to_Vec3i(kernel_size, value_constraint=ValueConstraint.POSITIVE)
425437
stride = to_Vec3i(stride, value_constraint=ValueConstraint.POSITIVE)
426438

427-
if not _vec_is_all(kernel_size, kernel_size[0].item()):
428-
raise NotImplementedError("Non-uniform kernel sizes are not currently supported")
439+
if not _vec_is_all(stride, 1):
440+
raise NotImplementedError("Strides not equal to 1 are not currently supported")
441+
429442
if not _vec_is_all(stride, stride[0].item()):
430443
raise NotImplementedError("Non-uniform strides are not currently supported")
431444

432445
backend = expert_config.get("backend", "default")
446+
if backend != "default":
447+
raise NotImplementedError("Non-default backends are not currently supported")
448+
433449
if backend == "dense":
434450
if target_grid is not None:
435451
raise ValueError("Target grid must be None for dense backend, transposed.")

0 commit comments

Comments
 (0)