@@ -74,35 +74,40 @@ def connection(self) -> "PostgresConnection":
7474
7575
7676class Record (Mapping ):
77+ __slots__ = (
78+ "_row" ,
79+ "_result_columns" ,
80+ "_dialect" ,
81+ "_column_map" ,
82+ "_column_map_int" ,
83+ "_column_map_full" ,
84+ )
85+
7786 def __init__ (
78- self , row : asyncpg .Record , result_columns : tuple , dialect : Dialect
87+ self ,
88+ row : asyncpg .Record ,
89+ result_columns : tuple ,
90+ dialect : Dialect ,
91+ column_maps : typing .Tuple [
92+ typing .Mapping [typing .Any , typing .Tuple [int , TypeEngine ]],
93+ typing .Mapping [int , typing .Tuple [int , TypeEngine ]],
94+ typing .Mapping [str , typing .Tuple [int , TypeEngine ]],
95+ ],
7996 ) -> None :
8097 self ._row = row
8198 self ._result_columns = result_columns
8299 self ._dialect = dialect
83- self ._column_map = (
84- {}
85- ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
86- self ._column_map_int = (
87- {}
88- ) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]]
89- self ._column_map_full = (
90- {}
91- ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
92- for idx , (column_name , _ , column , datatype ) in enumerate (self ._result_columns ):
93- self ._column_map [column_name ] = (idx , datatype )
94- self ._column_map_int [idx ] = (idx , datatype )
95- self ._column_map_full [str (column [0 ])] = (idx , datatype )
100+ self ._column_map , self ._column_map_int , self ._column_map_full = column_maps
96101
97102 def values (self ) -> typing .ValuesView :
98103 return self ._row .values ()
99104
100105 def __getitem__ (self , key : typing .Any ) -> typing .Any :
101106 if len (self ._column_map ) == 0 : # raw query
102107 return self ._row [tuple (self ._row .keys ()).index (key )]
103- elif type (key ) is Column :
108+ elif isinstance (key , Column ) :
104109 idx , datatype = self ._column_map_full [str (key )]
105- elif type (key ) is int :
110+ elif isinstance (key , int ) :
106111 idx , datatype = self ._column_map_int [key ]
107112 else :
108113 idx , datatype = self ._column_map [key ]
@@ -145,15 +150,22 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
145150 assert self ._connection is not None , "Connection is not acquired"
146151 query , args , result_columns = self ._compile (query )
147152 rows = await self ._connection .fetch (query , * args )
148- return [Record (row , result_columns , self ._dialect ) for row in rows ]
153+ dialect = self ._dialect
154+ column_maps = self ._create_column_maps (result_columns )
155+ return [Record (row , result_columns , dialect , column_maps ) for row in rows ]
149156
150157 async def fetch_one (self , query : ClauseElement ) -> typing .Optional [typing .Mapping ]:
151158 assert self ._connection is not None , "Connection is not acquired"
152159 query , args , result_columns = self ._compile (query )
153160 row = await self ._connection .fetchrow (query , * args )
154161 if row is None :
155162 return None
156- return Record (row , result_columns , self ._dialect )
163+ return Record (
164+ row ,
165+ result_columns ,
166+ self ._dialect ,
167+ self ._create_column_maps (result_columns ),
168+ )
157169
158170 async def fetch_val (
159171 self , query : ClauseElement , column : typing .Any = 0
@@ -181,8 +193,9 @@ async def iterate(
181193 ) -> typing .AsyncGenerator [typing .Any , None ]:
182194 assert self ._connection is not None , "Connection is not acquired"
183195 query , args , result_columns = self ._compile (query )
196+ column_maps = self ._create_column_maps (result_columns )
184197 async for row in self ._connection .cursor (query , * args ):
185- yield Record (row , result_columns , self ._dialect )
198+ yield Record (row , result_columns , self ._dialect , column_maps )
186199
187200 def transaction (self ) -> TransactionBackend :
188201 return PostgresTransaction (connection = self )
@@ -208,6 +221,34 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
208221 )
209222 return compiled_query , args , compiled ._result_columns
210223
224+ @staticmethod
225+ def _create_column_maps (
226+ result_columns : tuple ,
227+ ) -> typing .Tuple [
228+ typing .Mapping [typing .Any , typing .Tuple [int , TypeEngine ]],
229+ typing .Mapping [int , typing .Tuple [int , TypeEngine ]],
230+ typing .Mapping [str , typing .Tuple [int , TypeEngine ]],
231+ ]:
232+ """
233+ Generate column -> datatype mappings from the column definitions.
234+
235+ These mappings are used throughout PostgresConnection methods
236+ to initialize Record-s. The underlying DB driver does not do type
237+ conversion for us so we have wrap the returned asyncpg.Record-s.
238+
239+ :return: Three mappings from different ways to address a column to \
240+ corresponding column indexes and datatypes: \
241+ 1. by column identifier; \
242+ 2. by column index; \
243+ 3. by column name in Column sqlalchemy objects.
244+ """
245+ column_map , column_map_int , column_map_full = {}, {}, {}
246+ for idx , (column_name , _ , column , datatype ) in enumerate (result_columns ):
247+ column_map [column_name ] = (idx , datatype )
248+ column_map_int [idx ] = (idx , datatype )
249+ column_map_full [str (column [0 ])] = (idx , datatype )
250+ return column_map , column_map_int , column_map_full
251+
211252 @property
212253 def raw_connection (self ) -> asyncpg .connection .Connection :
213254 assert self ._connection is not None , "Connection is not acquired"
0 commit comments