|
37 | 37 | "me-south-1", |
38 | 38 | ] |
39 | 39 |
|
| 40 | +NO_P3_REGIONS = [ |
| 41 | + "af-south-1", |
| 42 | + "ap-east-1", |
| 43 | + "ap-southeast-1", # it has p3, but not enough |
| 44 | + "ap-southeast-2", # it has p3, but not enough |
| 45 | + "ca-central-1", # it has p3, but not enough |
| 46 | + "eu-central-1", # it has p3, but not enough |
| 47 | + "eu-north-1", |
| 48 | + "eu-west-2", # it has p3, but not enough |
| 49 | + "eu-west-3", |
| 50 | + "eu-south-1", |
| 51 | + "me-south-1", |
| 52 | + "sa-east-1", |
| 53 | + "us-west-1", |
| 54 | + "ap-northeast-1", # it has p3, but not enough |
| 55 | + "ap-south-1", |
| 56 | + "ap-northeast-2", # it has p3, but not enough |
| 57 | +] |
| 58 | + |
40 | 59 | NO_T2_REGIONS = ["eu-north-1", "ap-east-1", "me-south-1"] |
41 | 60 |
|
42 | 61 | FRAMEWORKS_FOR_GENERATED_VERSION_FIXTURES = ( |
@@ -361,9 +380,13 @@ def cpu_instance_type(sagemaker_session, request): |
361 | 380 | return "ml.m4.xlarge" |
362 | 381 |
|
363 | 382 |
|
364 | | -@pytest.fixture(scope="module") |
365 | | -def gpu_instance_type(request): |
366 | | - return "ml.p3.2xlarge" |
| 383 | +@pytest.fixture(scope="session") |
| 384 | +def gpu_instance_type(sagemaker_session, request): |
| 385 | + region = sagemaker_session.boto_session.region_name |
| 386 | + if region in NO_P3_REGIONS: |
| 387 | + return "ml.p2.xlarge" |
| 388 | + else: |
| 389 | + return "ml.p3.2xlarge" |
367 | 390 |
|
368 | 391 |
|
369 | 392 | @pytest.fixture(scope="session") |
@@ -405,10 +428,16 @@ def pytest_generate_tests(metafunc): |
405 | 428 |
|
406 | 429 | params = [cpu_instance_type] |
407 | 430 | if not ( |
| 431 | + region in tests.integ.HOSTING_NO_P3_REGIONS |
| 432 | + or region in tests.integ.TRAINING_NO_P3_REGIONS |
| 433 | + ): |
| 434 | + params.append("ml.p3.2xlarge") |
| 435 | + elif not ( |
408 | 436 | region in tests.integ.HOSTING_NO_P2_REGIONS |
409 | 437 | or region in tests.integ.TRAINING_NO_P2_REGIONS |
410 | 438 | ): |
411 | | - params.append("ml.p3.2xlarge") |
| 439 | + params.append("ml.p2.xlarge") |
| 440 | + |
412 | 441 | metafunc.parametrize("instance_type", params, scope="session") |
413 | 442 |
|
414 | 443 | _generate_all_framework_version_fixtures(metafunc) |
|
0 commit comments