|
5 | 5 | import cv2 |
6 | 6 | import numpy as np |
7 | 7 | import pytest |
| 8 | +import torch |
8 | 9 |
|
| 10 | +from tests.conftest import timed |
| 11 | +from tiatoolbox import logger, rcParam |
9 | 12 | from tiatoolbox.tools.registration.wsi_registration import ( |
10 | 13 | AffineWSITransformer, |
11 | 14 | DFBRegister, |
@@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None: |
576 | 579 | expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE) |
577 | 580 |
|
578 | 581 | assert np.sum(expected - output) == 0 |
| 582 | + |
| 583 | + |
| 584 | +def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None: |
| 585 | + """Test DFBRFeatureExtractor with torch.compile functionality. |
| 586 | +
|
| 587 | + Args: |
| 588 | + dfbr_features (Path): Path to the expected features. |
| 589 | +
|
| 590 | + """ |
| 591 | + |
| 592 | + def _extract_features() -> tuple: |
| 593 | + dfbr = DFBRegister() |
| 594 | + fixed_img = np.repeat( |
| 595 | + np.expand_dims( |
| 596 | + np.repeat( |
| 597 | + np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1), |
| 598 | + 64, |
| 599 | + axis=1, |
| 600 | + ), |
| 601 | + axis=2, |
| 602 | + ), |
| 603 | + 3, |
| 604 | + axis=2, |
| 605 | + ) |
| 606 | + output = dfbr.extract_features(fixed_img, fixed_img) |
| 607 | + pool3_feat = output["block3_pool"][0, :].detach().numpy() |
| 608 | + pool4_feat = output["block4_pool"][0, :].detach().numpy() |
| 609 | + pool5_feat = output["block5_pool"][0, :].detach().numpy() |
| 610 | + |
| 611 | + return pool3_feat, pool4_feat, pool5_feat |
| 612 | + |
| 613 | + torch_compile_mode = rcParam["torch_compile_mode"] |
| 614 | + torch._dynamo.reset() |
| 615 | + rcParam["torch_compile_mode"] = "default" |
| 616 | + (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features) |
| 617 | + _pool3_feat, _pool4_feat, _pool5_feat = np.load( |
| 618 | + str(dfbr_features), |
| 619 | + allow_pickle=True, |
| 620 | + ) |
| 621 | + assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4 |
| 622 | + assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4 |
| 623 | + assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4 |
| 624 | + logger.info("torch.compile default mode: %s", compile_time) |
| 625 | + torch._dynamo.reset() |
| 626 | + rcParam["torch_compile_mode"] = "reduce-overhead" |
| 627 | + (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features) |
| 628 | + _pool3_feat, _pool4_feat, _pool5_feat = np.load( |
| 629 | + str(dfbr_features), |
| 630 | + allow_pickle=True, |
| 631 | + ) |
| 632 | + assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4 |
| 633 | + assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4 |
| 634 | + assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4 |
| 635 | + logger.info("torch.compile reduce-overhead mode: %s", compile_time) |
| 636 | + torch._dynamo.reset() |
| 637 | + rcParam["torch_compile_mode"] = "max-autotune" |
| 638 | + (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features) |
| 639 | + _pool3_feat, _pool4_feat, _pool5_feat = np.load( |
| 640 | + str(dfbr_features), |
| 641 | + allow_pickle=True, |
| 642 | + ) |
| 643 | + assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4 |
| 644 | + assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4 |
| 645 | + assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4 |
| 646 | + logger.info("torch.compile max-autotune mode: %s", compile_time) |
| 647 | + torch._dynamo.reset() |
| 648 | + rcParam["torch_compile_mode"] = torch_compile_mode |
0 commit comments