33from typing import TYPE_CHECKING , Generator , Iterable
44
55from domain .core .aggregate_root import AggregateRoot
6+ from domain .core .enum import EntityType
67from domain .repository .errors import ItemNotFound
78from domain .repository .keys import KEY_SEPARATOR , TableKey
89from domain .repository .marshall import marshall , unmarshall
@@ -106,6 +107,7 @@ def create_index(
106107 parent_key_parts : tuple [str ],
107108 data : dict ,
108109 root : bool ,
110+ row_type : str ,
109111 table_key : TableKey = None ,
110112 parent_table_keys : tuple [TableKey ] = None ,
111113 ) -> TransactItem :
@@ -119,30 +121,37 @@ def create_index(
119121 f"Expected provide { len (parent_table_keys )} parent key parts, got { len (parent_key_parts )} "
120122 )
121123
122- write_key = table_key .key (id )
123- read_key = KEY_SEPARATOR .join (
124+ sort_key = table_key .key (id )
125+ partition_key = KEY_SEPARATOR .join (
124126 table_key .key (_id )
125127 for table_key , _id in zip (parent_table_keys , parent_key_parts )
126128 )
127129
130+ item_data = {
131+ "pk" : partition_key ,
132+ "sk" : sort_key ,
133+ "pk_read_1" : sort_key ,
134+ "sk_read_1" : sort_key ,
135+ "root" : root ,
136+ "row_type" : row_type ,
137+ ** data ,
138+ }
139+
140+ if row_type != EntityType .PRODUCT_TEAM_ALIAS :
141+ item_data ["pk_read_2" ] = TableKey .ORG_CODE .key (data ["ods_code" ])
142+ item_data ["sk_read_2" ] = sort_key
143+
128144 return TransactItem (
129145 Put = TransactionStatement (
130146 TableName = self .table_name ,
131- Item = marshall (
132- pk = write_key ,
133- sk = write_key ,
134- pk_read_1 = read_key ,
135- sk_read_1 = write_key ,
136- root = root ,
137- ** data ,
138- ),
147+ Item = marshall (** item_data ),
139148 ConditionExpression = ConditionExpression .MUST_NOT_EXIST ,
140149 )
141150 )
142151
143- def update_indexes (self , id : str , keys : list [str ], data : dict ):
152+ def update_indexes (self , pk : str , id : str , keys : list [str ], data : dict ):
144153 primary_keys = [
145- marshall (pk = pk , sk = pk ) for pk in map (self .table_key .key , [id , * keys ])
154+ marshall (pk = pk , sk = sk ) for sk in map (self .table_key .key , [id , * keys ])
146155 ]
147156 return update_transactions (
148157 table_name = self .table_name , primary_keys = primary_keys , data = data
@@ -161,21 +170,18 @@ def delete_index(self, id: str):
161170 def _query (
162171 self , parent_ids : tuple [str ], id : str = None , status : str = "all"
163172 ) -> list [dict ]:
164- pk_read_1 = KEY_SEPARATOR .join (
173+ pk = KEY_SEPARATOR .join (
165174 table_key .key (_id )
166175 for table_key , _id in zip (self .parent_table_keys , parent_ids )
167176 )
168- sk_read_1 = self .table_key .key (id or "" )
177+ sk = self .table_key .key (id or "" )
169178
170179 sk_query_type = QueryType .BEGINS_WITH if id is None else QueryType .EQUALS
171- sk_condition = sk_query_type .format ("sk_read_1 " , ":sk_read_1 " )
180+ sk_condition = sk_query_type .format ("sk " , ":sk " )
172181 args = {
173182 "TableName" : self .table_name ,
174- "IndexName" : "idx_gsi_read_1" ,
175- "KeyConditionExpression" : f"pk_read_1 = :pk_read_1 AND { sk_condition } " ,
176- "ExpressionAttributeValues" : marshall (
177- ** {":pk_read_1" : pk_read_1 , ":sk_read_1" : sk_read_1 }
178- ),
183+ "KeyConditionExpression" : f"pk = :pk AND { sk_condition } " ,
184+ "ExpressionAttributeValues" : marshall (** {":pk" : pk , ":sk" : sk }),
179185 }
180186 if status != "all" :
181187 args ["FilterExpression" ] = "#status = :status"
@@ -199,5 +205,7 @@ def _read(self, parent_ids: tuple[str], id: str, status: str = "all") -> ModelTy
199205 try :
200206 (item ,) = items
201207 except ValueError :
208+ if id in parent_ids :
209+ raise ItemNotFound (id , item_type = self .model )
202210 raise ItemNotFound (* filter (bool , parent_ids ), id , item_type = self .model )
203211 return self .model (** item )
0 commit comments