
Commit 4b102ee

ajaykarpur and laurenyu authored

fix: do not use script for TFS when entry_point is not provided (#1252)

* fix: Do not set script in prepare_framework_container_def for TFS
* remove _build_airflow_workflow_tf
* fix: Set script_name and dir_name env vars to None for TFS

Co-authored-by: Lauren Yu <[email protected]>

1 parent ede60e7 commit 4b102ee

File tree

3 files changed: +18 −56 lines changed


src/sagemaker/model.py

Lines changed: 4 additions & 1 deletion
@@ -844,9 +844,12 @@ def _framework_env_vars(self):
                 dir_name = "/opt/ml/model/code"
             else:
                 dir_name = self.uploaded_code.s3_prefix
-        else:
+        elif self.entry_point is not None:
             script_name = self.entry_point
             dir_name = "file://" + self.source_dir
+        else:
+            script_name = None
+            dir_name = None

         return {
             SCRIPT_PARAM_NAME.upper(): script_name,
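Note on the hunk above: a TensorFlow Serving model created without an entry_point has no uploaded_code, so the old else branch evaluated "file://" + self.source_dir with source_dir None and raised a TypeError. A minimal standalone sketch of the new branching, with SCRIPT_PARAM_NAME.upper() / DIR_PARAM_NAME.upper() inlined to their values and the helper name purely illustrative (not the SDK's API):

def framework_env_vars(uploaded_code, entry_point, source_dir, network_isolation=False):
    # Sketch of Model._framework_env_vars after this commit; `uploaded_code`
    # stands in for sagemaker.fw_utils.UploadedCode.
    if uploaded_code:
        script_name = uploaded_code.script_name
        dir_name = "/opt/ml/model/code" if network_isolation else uploaded_code.s3_prefix
    elif entry_point is not None:
        script_name = entry_point
        dir_name = "file://" + source_dir
    else:
        # New fallback for TFS with no entry point: plain None instead of
        # a TypeError from `"file://" + None`.
        script_name = None
        dir_name = None
    return {"SAGEMAKER_PROGRAM": script_name, "SAGEMAKER_SUBMIT_DIRECTORY": dir_name}

# TFS case: nothing uploaded, no entry point given.
assert framework_env_vars(None, None, None) == {
    "SAGEMAKER_PROGRAM": None,
    "SAGEMAKER_SUBMIT_DIRECTORY": None,
}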

src/sagemaker/workflow/airflow.py

Lines changed: 13 additions & 12 deletions
@@ -522,18 +522,19 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
     model.name = model.name or utils.name_from_base(base_name)

     bucket = model.bucket or model.sagemaker_session._default_bucket
-    script = os.path.basename(model.entry_point)
-    key = "{}/source/sourcedir.tar.gz".format(model.name)
-
-    if model.source_dir and model.source_dir.lower().startswith("s3://"):
-        code_dir = model.source_dir
-        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
-    else:
-        code_dir = "s3://{}/{}".format(bucket, key)
-        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
-        s3_operations["S3Upload"] = [
-            {"Path": model.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True}
-        ]
+    if model.entry_point is not None:
+        script = os.path.basename(model.entry_point)
+        key = "{}/source/sourcedir.tar.gz".format(model.name)
+
+        if model.source_dir and model.source_dir.lower().startswith("s3://"):
+            code_dir = model.source_dir
+            model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
+        else:
+            code_dir = "s3://{}/{}".format(bucket, key)
+            model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
+            s3_operations["S3Upload"] = [
+                {"Path": model.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True}
+            ]

     deploy_env = dict(model.env)
     deploy_env.update(model._framework_env_vars())
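Seen from the caller's side, the new guard means model_config for a TensorFlow Serving model with no entry_point skips the whole script/sourcedir-upload block instead of failing inside os.path.basename(None). A hedged usage sketch (bucket, role, and version strings are placeholders; assumes SDK v1 import paths and configured AWS credentials):

from sagemaker.tensorflow.serving import Model
import sagemaker.workflow.airflow as sm_airflow

# Placeholder artifact location and IAM role name.
model = Model(
    model_data="s3://my-bucket/tfs/model.tar.gz",
    role="SageMakerRole",
    framework_version="1.15.2",
)

# entry_point is None here, so prepare_framework_container_def no longer
# touches the script/sourcedir handling and emits no S3Upload operation.
config = sm_airflow.model_config("ml.c5.xlarge", model)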

tests/integ/test_airflow_config.py

Lines changed: 1 addition & 43 deletions
@@ -572,7 +572,7 @@ def test_tf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
         path=os.path.join(TF_MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
     )

-    training_config = _build_airflow_workflow_tf(
+    training_config = _build_airflow_workflow(
         estimator=tf, instance_type=cpu_instance_type, inputs=inputs
     )

@@ -701,45 +701,3 @@ def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_si
     transform_op.set_upstream(train_op)

     return training_config
-
-
-def _build_airflow_workflow_tf(estimator, instance_type, inputs=None, mini_batch_size=None):
-    training_config = sm_airflow.training_config(
-        estimator=estimator, inputs=inputs, mini_batch_size=mini_batch_size
-    )
-
-    model = estimator.create_model(entry_point=estimator.entry_point)
-    assert model is not None
-
-    model_config = sm_airflow.model_config(instance_type, model)
-    assert model_config is not None
-
-    transform_config = sm_airflow.transform_config_from_estimator(
-        estimator=estimator,
-        task_id="transform_config",
-        task_type="training",
-        instance_count=SINGLE_INSTANCE_COUNT,
-        instance_type=estimator.train_instance_type,
-        data=inputs,
-        content_type="text/csv",
-    )
-
-    default_args = {
-        "owner": "airflow",
-        "start_date": airflow.utils.dates.days_ago(2),
-        "provide_context": True,
-    }
-
-    dag = DAG("tensorflow_example", default_args=default_args, schedule_interval="@once")
-
-    train_op = SageMakerTrainingOperator(
-        task_id="tf_training", config=training_config, wait_for_completion=True, dag=dag
-    )
-
-    transform_op = SageMakerTransformOperator(
-        task_id="transform_operator", config=transform_config, wait_for_completion=True, dag=dag
-    )
-
-    transform_op.set_upstream(train_op)
-
-    return training_config
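The deleted helper duplicated _build_airflow_workflow except for passing entry_point explicitly when creating the model. With the model.py fix, that workaround is unnecessary; a hypothetical sketch of the model step the shared helper can now perform for TensorFlow as well (the helper name below is illustrative, not the test's actual code):

import sagemaker.workflow.airflow as sm_airflow

def _model_step(estimator, instance_type):
    # No entry_point=estimator.entry_point workaround needed anymore: a TFS
    # model without an entry point now yields valid (None) script env vars.
    model = estimator.create_model()
    assert model is not None
    model_config = sm_airflow.model_config(instance_type, model)
    assert model_config is not None
    return model_config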
