|
41 | 41 | "us-west-1": "763104351884", |
42 | 42 | "us-west-2": "763104351884", |
43 | 43 | } |
44 | | -VERSIONS = ["0.21.0", "0.20.0", "0.19.0"] |
45 | | -DJL_FRAMEWORKS = ["djl-deepspeed"] |
| 44 | +DJL_DEEPSPEED_VERSIONS = ["0.21.0", "0.20.0", "0.19.0"] |
| 45 | +DJL_FASTERTRANSFORMER_VERSIONS = ["0.21.0"] |
46 | 46 | DJL_VERSIONS_TO_FRAMEWORK = { |
47 | 47 | "0.19.0": {"djl-deepspeed": "deepspeed0.7.3-cu113"}, |
48 | 48 | "0.20.0": {"djl-deepspeed": "deepspeed0.7.5-cu116"}, |
49 | | - "0.21.0": {"djl-deepspeed": "deepspeed0.8.0-cu117"}, |
| 49 | + "0.21.0": { |
| 50 | + "djl-deepspeed": "deepspeed0.8.0-cu117", |
| 51 | + "djl-fastertransformer": "fastertransformer5.3.0-cu117", |
| 52 | + }, |
50 | 53 | } |
51 | 54 |
|
52 | 55 |
|
53 | 56 | @pytest.mark.parametrize("region", ACCOUNTS.keys()) |
54 | | -@pytest.mark.parametrize("version", VERSIONS) |
55 | | -@pytest.mark.parametrize("djl_framework", DJL_FRAMEWORKS) |
56 | | -def test_djl_uris(region, version, djl_framework): |
| 57 | +@pytest.mark.parametrize("version", DJL_DEEPSPEED_VERSIONS) |
| 58 | +def test_djl_deepspeed(region, version): |
| 59 | + _test_djl_uris(region, version, "djl-deepspeed") |
| 60 | + |
| 61 | + |
| 62 | +@pytest.mark.parametrize("region", ACCOUNTS.keys()) |
| 63 | +@pytest.mark.parametrize("version", DJL_FASTERTRANSFORMER_VERSIONS) |
| 64 | +def test_djl_fastertransformer(region, version): |
| 65 | + _test_djl_uris(region, version, "djl-fastertransformer") |
| 66 | + |
| 67 | + |
| 68 | +def _test_djl_uris(region, version, djl_framework): |
57 | 69 | uri = image_uris.retrieve(framework=djl_framework, region=region, version=version) |
58 | 70 | expected = expected_uris.djl_framework_uri( |
59 | 71 | "djl-inference", |
|
0 commit comments