|
24 | 24 | CustomResource, |
25 | 25 | Tags, |
26 | 26 | ) |
| 27 | +from constructs import DependencyGroup |
| 28 | +from botocore.exceptions import ClientError |
27 | 29 |
|
28 | 30 | from .manager import stack |
29 | 31 | from .policies.data_policy import DataPolicy |
30 | 32 | from .policies.service_policy import ServicePolicy |
31 | 33 | from ... import db |
32 | 34 | from ...aws.handlers.quicksight import Quicksight |
| 35 | +from ...aws.handlers.parameter_store import ParameterStoreManager |
33 | 36 | from ...aws.handlers.sagemaker_studio import ( |
34 | 37 | SagemakerStudio, |
35 | 38 | ) |
@@ -67,12 +70,23 @@ def init_quicksight(self, environment: models.Environment): |
67 | 70 |
|
68 | 71 | def check_sagemaker_studio(self, engine, environment: models.Environment): |
69 | 72 | logger.info('check sagemaker studio domain creation') |
70 | | - existing_domain = SagemakerStudio.get_sagemaker_studio_domain( |
71 | | - environment.AwsAccountId, environment.region |
72 | | - ) |
73 | | - existing_domain_id = existing_domain.get('DomainId', False) |
74 | | - if existing_domain_id: |
75 | | - return existing_domain_id |
| 73 | + |
| 74 | + try: |
| 75 | + dataall_created_domain = ParameterStoreManager.client( |
| 76 | + AwsAccountId=environment.AwsAccountId, |
| 77 | + region=environment.region |
| 78 | + ).get_parameter( |
| 79 | + Name=f'/dataall/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id' |
| 80 | + ) |
| 81 | + return None |
| 82 | + except ClientError as e: |
| 83 | + logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}') |
| 84 | + existing_domain = SagemakerStudio.get_sagemaker_studio_domain( |
| 85 | + environment.AwsAccountId, environment.region |
| 86 | + ) |
| 87 | + existing_domain_id = existing_domain.get('DomainId', False) |
| 88 | + if existing_domain_id: |
| 89 | + return existing_domain_id |
76 | 90 |
|
77 | 91 | @staticmethod |
78 | 92 | def get_environment_group_permissions(engine, environmentUri, group): |
@@ -167,93 +181,16 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): |
167 | 181 | self.engine, self._environment |
168 | 182 | ) |
169 | 183 |
|
170 | | - self.sagemaker_domain_exists = self.check_sagemaker_studio(engine=self.engine, environment=self._environment) |
171 | | - |
172 | | - if self._environment.mlStudiosEnabled and not (self.sagemaker_domain_exists): |
173 | | - |
174 | | - sagemaker_domain_role = iam.Role( |
175 | | - self, |
176 | | - 'RoleForSagemakerStudioUsers', |
177 | | - assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), |
178 | | - role_name="RoleSagemakerStudioUsers", |
179 | | - managed_policies=[iam.ManagedPolicy.from_managed_policy_arn( |
180 | | - self, |
181 | | - id="SagemakerFullAccess", |
182 | | - managed_policy_arn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"), |
183 | | - iam.ManagedPolicy.from_managed_policy_arn( |
184 | | - self, |
185 | | - id="S3FullAccess", |
186 | | - managed_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess") |
187 | | - ] |
188 | | - ) |
189 | | - |
190 | | - sagemaker_domain_key = kms.Key( |
191 | | - self, |
192 | | - 'SagemakerDomainKmsKey', |
193 | | - alias="SagemakerStudioDomain", |
194 | | - enable_key_rotation=True, |
195 | | - policy=iam.PolicyDocument( |
196 | | - assign_sids=True, |
197 | | - statements=[ |
198 | | - iam.PolicyStatement( |
199 | | - resources=['*'], |
200 | | - effect=iam.Effect.ALLOW, |
201 | | - principals=[ |
202 | | - iam.AccountPrincipal(account_id=self._environment.AwsAccountId), |
203 | | - iam.Role.from_role_arn(self, 'DomainRole', role_arn=sagemaker_domain_role.role_arn), |
204 | | - iam.Role.from_role_arn(self, 'EnvironmentDefaultRole', role_arn=self.environment_admins_group.environmentIAMRoleArn), |
205 | | - ] + [iam.Role.from_role_arn(self, f'{group.groupUri}Role', role_arn=group.environmentIAMRoleArn) for group in self.environment_groups], |
206 | | - actions=['kms:*'], |
207 | | - ) |
208 | | - ], |
209 | | - ), |
210 | | - ) |
211 | | - |
212 | | - try: |
213 | | - default_vpc = ec2.Vpc.from_lookup(self, 'VPCStudio', is_default=True) |
214 | | - vpc_id = default_vpc.vpc_id |
215 | | - subnet_ids = [private_subnet.subnet_id for private_subnet in default_vpc.private_subnets] |
216 | | - subnet_ids += [public_subnet.subnet_id for public_subnet in default_vpc.public_subnets] |
217 | | - subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in default_vpc.isolated_subnets] |
218 | | - except Exception as e: |
219 | | - logger.error(f"Default VPC not found, Exception: {e}. If you don't own a default VPC, modify the networking configuration, or disable ML Studio upon environment creation.") |
220 | | - |
221 | | - sagemaker_domain = sagemaker.CfnDomain( |
222 | | - self, |
223 | | - "SagemakerStudioDomain", |
224 | | - domain_name=f"SagemakerStudioDomain-{self._environment.region}-{self._environment.AwsAccountId}", |
225 | | - auth_mode="IAM", |
226 | | - |
227 | | - default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( |
228 | | - execution_role=sagemaker_domain_role.role_arn, |
229 | | - |
230 | | - security_groups=[], |
231 | | - |
232 | | - sharing_settings=sagemaker.CfnDomain.SharingSettingsProperty( |
233 | | - notebook_output_option="Allowed", |
234 | | - s3_kms_key_id=sagemaker_domain_key.key_id, |
235 | | - s3_output_path=f"s3://sagemaker-{self._environment.region}-{self._environment.AwsAccountId}", |
236 | | - ) |
237 | | - ), |
238 | | - |
239 | | - vpc_id=vpc_id, |
240 | | - subnet_ids=subnet_ids, |
241 | | - app_network_access_type="VpcOnly", |
242 | | - kms_key_id=sagemaker_domain_key.key_id, |
243 | | - ) |
244 | | - |
245 | | - ssm.StringParameter( |
246 | | - self, |
247 | | - 'SagemakerStudioDomainId', |
248 | | - string_value=sagemaker_domain.attr_domain_id, |
249 | | - parameter_name=f'/datahub/{self._environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', |
250 | | - ) |
| 184 | + roles_sagemaker_dependency_group = DependencyGroup() |
251 | 185 |
|
252 | 186 | if self._environment.dashboardsEnabled: |
253 | 187 | logger.warning('ensure_quicksight_default_group') |
254 | 188 | self.init_quicksight(environment=self._environment) |
255 | 189 |
|
256 | | - self.create_or_import_environment_groups_roles() |
| 190 | + group_roles = self.create_or_import_environment_groups_roles() |
| 191 | + |
| 192 | + for group_role in group_roles: |
| 193 | + roles_sagemaker_dependency_group.add(group_role) |
257 | 194 |
|
258 | 195 | central_account = SessionHelper.get_account() |
259 | 196 |
|
@@ -330,7 +267,8 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): |
330 | 267 | destination_key_prefix='profiling/code', |
331 | 268 | ) |
332 | 269 |
|
333 | | - self.create_or_import_environment_default_role() |
| 270 | + default_role = self.create_or_import_environment_default_role() |
| 271 | + roles_sagemaker_dependency_group.add(default_role) |
334 | 272 |
|
335 | 273 | self.create_default_athena_workgroup( |
336 | 274 | default_environment_bucket, |
@@ -568,6 +506,89 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): |
568 | 506 | self._environment, |
569 | 507 | ) |
570 | 508 |
|
| 509 | + self.sagemaker_domain_exists = self.check_sagemaker_studio(engine=self.engine, environment=self._environment) |
| 510 | + |
| 511 | + if self._environment.mlStudiosEnabled and not (self.sagemaker_domain_exists): |
| 512 | + |
| 513 | + sagemaker_domain_role = iam.Role( |
| 514 | + self, |
| 515 | + 'RoleForSagemakerStudioUsers', |
| 516 | + assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), |
| 517 | + role_name="RoleSagemakerStudioUsers", |
| 518 | + managed_policies=[iam.ManagedPolicy.from_managed_policy_arn( |
| 519 | + self, |
| 520 | + id="SagemakerFullAccess", |
| 521 | + managed_policy_arn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"), |
| 522 | + iam.ManagedPolicy.from_managed_policy_arn( |
| 523 | + self, |
| 524 | + id="S3FullAccess", |
| 525 | + managed_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess") |
| 526 | + ] |
| 527 | + ) |
| 528 | + |
| 529 | + sagemaker_domain_key = kms.Key( |
| 530 | + self, |
| 531 | + 'SagemakerDomainKmsKey', |
| 532 | + alias="SagemakerStudioDomain", |
| 533 | + enable_key_rotation=True, |
| 534 | + policy=iam.PolicyDocument( |
| 535 | + assign_sids=True, |
| 536 | + statements=[ |
| 537 | + iam.PolicyStatement( |
| 538 | + resources=['*'], |
| 539 | + effect=iam.Effect.ALLOW, |
| 540 | + principals=[ |
| 541 | + iam.AccountPrincipal(account_id=self._environment.AwsAccountId), |
| 542 | + sagemaker_domain_role, |
| 543 | + default_role, |
| 544 | + ] + group_roles, |
| 545 | + actions=['kms:*'], |
| 546 | + ) |
| 547 | + ], |
| 548 | + ), |
| 549 | + ) |
| 550 | + sagemaker_domain_key.node.add_dependency(roles_sagemaker_dependency_group) |
| 551 | + |
| 552 | + try: |
| 553 | + default_vpc = ec2.Vpc.from_lookup(self, 'VPCStudio', is_default=True) |
| 554 | + vpc_id = default_vpc.vpc_id |
| 555 | + subnet_ids = [private_subnet.subnet_id for private_subnet in default_vpc.private_subnets] |
| 556 | + subnet_ids += [public_subnet.subnet_id for public_subnet in default_vpc.public_subnets] |
| 557 | + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in default_vpc.isolated_subnets] |
| 558 | + except Exception as e: |
| 559 | + logger.error(f"Default VPC not found, Exception: {e}. If you don't own a default VPC, modify the networking configuration, or disable ML Studio upon environment creation.") |
| 560 | + |
| 561 | + sagemaker_domain = sagemaker.CfnDomain( |
| 562 | + self, |
| 563 | + "SagemakerStudioDomain", |
| 564 | + domain_name=f"SagemakerStudioDomain-{self._environment.region}-{self._environment.AwsAccountId}", |
| 565 | + auth_mode="IAM", |
| 566 | + |
| 567 | + default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( |
| 568 | + execution_role=sagemaker_domain_role.role_arn, |
| 569 | + |
| 570 | + security_groups=[], |
| 571 | + |
| 572 | + sharing_settings=sagemaker.CfnDomain.SharingSettingsProperty( |
| 573 | + notebook_output_option="Allowed", |
| 574 | + s3_kms_key_id=sagemaker_domain_key.key_id, |
| 575 | + s3_output_path=f"s3://sagemaker-{self._environment.region}-{self._environment.AwsAccountId}", |
| 576 | + ) |
| 577 | + ), |
| 578 | + |
| 579 | + vpc_id=vpc_id, |
| 580 | + subnet_ids=subnet_ids, |
| 581 | + app_network_access_type="VpcOnly", |
| 582 | + kms_key_id=sagemaker_domain_key.key_id, |
| 583 | + ) |
| 584 | + |
| 585 | + ssm.StringParameter( |
| 586 | + self, |
| 587 | + 'SagemakerStudioDomainId', |
| 588 | + string_value=sagemaker_domain.attr_domain_id, |
| 589 | + parameter_name=f'/dataall/{self._environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', |
| 590 | + ) |
| 591 | + |
571 | 592 | TagsUtil.add_tags(self) |
572 | 593 |
|
573 | 594 | CDKNagUtil.check_rules(self) |
@@ -635,15 +656,18 @@ def create_or_import_environment_default_role(self): |
635 | 656 |
|
636 | 657 | def create_or_import_environment_groups_roles(self): |
637 | 658 | group: models.EnvironmentGroup |
| 659 | + group_roles = [] |
638 | 660 | for group in self.environment_groups: |
639 | 661 | if not group.environmentIAMRoleImported: |
640 | | - self.create_group_environment_role(group) |
| 662 | + group_role = self.create_group_environment_role(group) |
| 663 | + group_roles.append(group_role) |
641 | 664 | else: |
642 | 665 | iam.Role.from_role_arn( |
643 | 666 | self, |
644 | 667 | f'{group.groupUri + group.environmentIAMRoleName}', |
645 | 668 | role_arn=f'arn:aws:iam::{self.environment.AwsAccountId}:role/{group.environmentIAMRoleName}', |
646 | 669 | ) |
| 670 | + return group_roles |
647 | 671 |
|
648 | 672 | def create_group_environment_role(self, group): |
649 | 673 |
|
|
0 commit comments