|
74 | 74 | "sagemaker_submit_directory": json.dumps("file:///tmp/code"), |
75 | 75 | } |
76 | 76 |
|
| 77 | +ENVIRONMENT = {"MYVAR": "HELLO_WORLD"} |
| 78 | + |
77 | 79 |
|
78 | 80 | @pytest.fixture() |
79 | 81 | def sagemaker_session(): |
@@ -352,7 +354,7 @@ def test_train( |
352 | 354 | "local", instance_count, image, sagemaker_session=sagemaker_session |
353 | 355 | ) |
354 | 356 | sagemaker_container.train( |
355 | | - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 357 | + INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME |
356 | 358 | ) |
357 | 359 |
|
358 | 360 | docker_compose_file = os.path.join( |
@@ -415,7 +417,7 @@ def test_train_with_hyperparameters_without_job_name( |
415 | 417 | "local", instance_count, image, sagemaker_session=sagemaker_session |
416 | 418 | ) |
417 | 419 | sagemaker_container.train( |
418 | | - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 420 | + INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME |
419 | 421 | ) |
420 | 422 |
|
421 | 423 | docker_compose_file = os.path.join( |
@@ -456,7 +458,11 @@ def test_train_error( |
456 | 458 |
|
457 | 459 | with pytest.raises(RuntimeError) as e: |
458 | 460 | sagemaker_container.train( |
459 | | - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME |
| 461 | + INPUT_DATA_CONFIG, |
| 462 | + OUTPUT_DATA_CONFIG, |
| 463 | + HYPERPARAMETERS, |
| 464 | + ENVIRONMENT, |
| 465 | + TRAINING_JOB_NAME, |
460 | 466 | ) |
461 | 467 |
|
462 | 468 | assert "this is expected" in str(e) |
@@ -486,7 +492,11 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session): |
486 | 492 | ) |
487 | 493 |
|
488 | 494 | sagemaker_container.train( |
489 | | - INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME |
| 495 | + INPUT_DATA_CONFIG, |
| 496 | + OUTPUT_DATA_CONFIG, |
| 497 | + LOCAL_CODE_HYPERPARAMETERS, |
| 498 | + ENVIRONMENT, |
| 499 | + TRAINING_JOB_NAME, |
490 | 500 | ) |
491 | 501 |
|
492 | 502 | docker_compose_file = os.path.join( |
@@ -538,7 +548,7 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem |
538 | 548 | hyperparameters = {"sagemaker_s3_output": output_path} |
539 | 549 |
|
540 | 550 | sagemaker_container.train( |
541 | | - INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME |
| 551 | + INPUT_DATA_CONFIG, output_data_config, hyperparameters, ENVIRONMENT, TRAINING_JOB_NAME |
542 | 552 | ) |
543 | 553 |
|
544 | 554 | docker_compose_file = os.path.join( |
|
0 commit comments