Skip to content

Commit b395baf

Browse files
committed
update test_SpectralConvTranspose2d_vanilla_export to run in channel_last in TF and Keras
1 parent 1e71b37 commit b395baf

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ $ make test
3232

3333
This command will:
3434
- check your code with black PEP-8 formatter and flake8 linter.
35-
- run `unittest` on the `tests/` folder with different Python and TensorFlow versions.
35+
- run `pytest` on the `tests/` folder with different Python and TensorFlow versions.
3636

3737

3838
## Submitting your changes

tests/test_layers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,9 @@ def test_SpectralConvTranspose2d_vanilla_export():
14221422
data_format="channels_first",
14231423
input_shape=(3, 28, 28),
14241424
)
1425+
kwargs["input_shape"] = uft.to_framework_channel(kwargs["input_shape"])
1426+
if kwargs["input_shape"][-1] == kwargs["in_channels"]:
1427+
kwargs["data_format"] = "channels_last"
14251428

14261429
model = uft.generate_k_lip_model(
14271430
SpectralConvTranspose2d, kwargs, kwargs["input_shape"], 1.0

0 commit comments

Comments
 (0)