@@ -205,6 +205,40 @@ def _log(self):
205205 def external (self ):
206206 return self .connection .schemas [self .database ].external
207207
208+ def update1 (self , row ):
209+ """
210+ Update an existing entry in the table.
211+ Caution: Updates are not part of the DataJoint data manipulation model. For strict data integrity,
212+ use delete and insert.
213+ :param row: a dict containing the primary key and the attributes to update.
214+ Setting an attribute value to None will reset it to the default value (if any)
215+ The primary key attributes must always be provided.
216+ Examples:
217+ >>> table.update1({'id': 1, 'value': 3}) # update value in record with id=1
218+ >>> table.update1({'id': 1, 'value': None}) # reset value to default
219+ """
220+ # argument validations
221+ if not isinstance (row , collections .abc .Mapping ):
222+ raise DataJointError ('The argument of update1 must be dict-like.' )
223+ if not set (row ).issuperset (self .primary_key ):
224+ raise DataJointError ('The argument of update1 must supply all primary key values.' )
225+ try :
226+ raise DataJointError ('Attribute `%s` not found.' % next (k for k in row if k not in self .heading .names ))
227+ except StopIteration :
228+ pass # ok
229+ if len (self .restriction ):
230+ raise DataJointError ('Update cannot be applied to a restricted table.' )
231+ key = {k : row [k ] for k in self .primary_key }
232+ if len (self & key ) != 1 :
233+ raise DataJointError ('Update entry must exist.' )
234+ # UPDATE query
235+ row = [self .__make_placeholder (k , v ) for k , v in row .items () if k not in self .primary_key ]
236+ query = "UPDATE {table} SET {assignments} WHERE {where}" .format (
237+ table = self .full_table_name ,
238+ assignments = "," .join ('`%s`=%s' % r [:2 ] for r in row ),
239+ where = self ._make_condition (key ))
240+ self .connection .query (query , args = list (r [2 ] for r in row if r [2 ] is not None ))
241+
208242 def insert1 (self , row , ** kwargs ):
209243 """
210244 Insert one data record or one Mapping (like a dict).
@@ -244,7 +278,6 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
244278 'Inserts into an auto-populated table can only done inside its make method during a populate call.'
245279 ' To override, set keyword argument allow_direct_insert=True.' )
246280
247- heading = self .heading
248281 if inspect .isclass (rows ) and issubclass (rows , QueryExpression ): # instantiate if a class
249282 rows = rows ()
250283 if isinstance (rows , QueryExpression ):
@@ -253,10 +286,10 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
253286 try :
254287 raise DataJointError (
255288 "Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
256- next (name for name in rows .heading if name not in heading ))
289+ next (name for name in rows .heading if name not in self . heading ))
257290 except StopIteration :
258291 pass
259- fields = list (name for name in rows .heading if name in heading )
292+ fields = list (name for name in rows .heading if name in self . heading )
260293 query = '{command} INTO {table} ({fields}) {select}{duplicate}' .format (
261294 command = 'REPLACE' if replace else 'INSERT' ,
262295 fields = '`' + '`,`' .join (fields ) + '`' ,
@@ -268,110 +301,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
268301 self .connection .query (query )
269302 return
270303
271- if heading .attributes is None :
272- logger .warning ('Could not access table {table}' .format (table = self .full_table_name ))
273- return
274-
275- field_list = None # ensures that all rows have the same attributes in the same order as the first row.
276-
277- def make_row_to_insert (row ):
278- """
279- :param row: A tuple to insert
280- :return: a dict with fields 'names', 'placeholders', 'values'
281- """
282- def make_placeholder (name , value ):
283- """
284- For a given attribute `name` with `value`, return its processed value or value placeholder
285- as a string to be included in the query and the value, if any, to be submitted for
286- processing by mysql API.
287- :param name: name of attribute to be inserted
288- :param value: value of attribute to be inserted
289- """
290- if ignore_extra_fields and name not in heading :
291- return None
292- attr = heading [name ]
293- if attr .adapter :
294- value = attr .adapter .put (value )
295- if value is None or (attr .numeric and (value == '' or np .isnan (np .float (value )))):
296- # set default value
297- placeholder , value = 'DEFAULT' , None
298- else : # not NULL
299- placeholder = '%s'
300- if attr .uuid :
301- if not isinstance (value , uuid .UUID ):
302- try :
303- value = uuid .UUID (value )
304- except (AttributeError , ValueError ):
305- raise DataJointError (
306- 'badly formed UUID value {v} for attribute `{n}`' .format (v = value , n = name ))
307- value = value .bytes
308- elif attr .is_blob :
309- value = blob .pack (value )
310- value = self .external [attr .store ].put (value ).bytes if attr .is_external else value
311- elif attr .is_attachment :
312- attachment_path = Path (value )
313- if attr .is_external :
314- # value is hash of contents
315- value = self .external [attr .store ].upload_attachment (attachment_path ).bytes
316- else :
317- # value is filename + contents
318- value = str .encode (attachment_path .name ) + b'\0 ' + attachment_path .read_bytes ()
319- elif attr .is_filepath :
320- value = self .external [attr .store ].upload_filepath (value ).bytes
321- elif attr .numeric :
322- value = str (int (value ) if isinstance (value , bool ) else value )
323- return name , placeholder , value
324-
325- def check_fields (fields ):
326- """
327- Validates that all items in `fields` are valid attributes in the heading
328- :param fields: field names of a tuple
329- """
330- if field_list is None :
331- if not ignore_extra_fields :
332- for field in fields :
333- if field not in heading :
334- raise KeyError (u'`{0:s}` is not in the table heading' .format (field ))
335- elif set (field_list ) != set (fields ).intersection (heading .names ):
336- raise DataJointError ('Attempt to insert rows with different fields' )
337-
338- if isinstance (row , np .void ): # np.array
339- check_fields (row .dtype .fields )
340- attributes = [make_placeholder (name , row [name ])
341- for name in heading if name in row .dtype .fields ]
342- elif isinstance (row , collections .abc .Mapping ): # dict-based
343- check_fields (row )
344- attributes = [make_placeholder (name , row [name ]) for name in heading if name in row ]
345- else : # positional
346- try :
347- if len (row ) != len (heading ):
348- raise DataJointError (
349- 'Invalid insert argument. Incorrect number of attributes: '
350- '{given} given; {expected} expected' .format (
351- given = len (row ), expected = len (heading )))
352- except TypeError :
353- raise DataJointError ('Datatype %s cannot be inserted' % type (row ))
354- else :
355- attributes = [make_placeholder (name , value ) for name , value in zip (heading , row )]
356- if ignore_extra_fields :
357- attributes = [a for a in attributes if a is not None ]
358-
359- assert len (attributes ), 'Empty tuple'
360- row_to_insert = dict (zip (('names' , 'placeholders' , 'values' ), zip (* attributes )))
361- nonlocal field_list
362- if field_list is None :
363- # first row sets the composition of the field list
364- field_list = row_to_insert ['names' ]
365- else :
366- # reorder attributes in row_to_insert to match field_list
367- order = list (row_to_insert ['names' ].index (field ) for field in field_list )
368- row_to_insert ['names' ] = list (row_to_insert ['names' ][i ] for i in order )
369- row_to_insert ['placeholders' ] = list (row_to_insert ['placeholders' ][i ] for i in order )
370- row_to_insert ['values' ] = list (row_to_insert ['values' ][i ] for i in order )
371-
372- return row_to_insert
373-
374- rows = list (make_row_to_insert (row ) for row in rows )
304+ field_list = [] # collects the field list from first row (passed by reference)
305+ rows = list (self .__make_row_to_insert (row , field_list , ignore_extra_fields ) for row in rows )
375306 if rows :
376307 try :
377308 query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}" .format (
@@ -638,6 +569,107 @@ def _update(self, attrname, value=None):
638569 where_clause = self .where_clause )
639570 self .connection .query (command , args = (value , ) if value is not None else ())
640571
572+ # --- private helper functions ----
573+ def __make_placeholder (self , name , value , ignore_extra_fields = False ):
574+ """
575+ For a given attribute `name` with `value`, return its processed value or value placeholder
576+ as a string to be included in the query and the value, if any, to be submitted for
577+ processing by mysql API.
578+ :param name: name of attribute to be inserted
579+ :param value: value of attribute to be inserted
580+ """
581+ if ignore_extra_fields and name not in self .heading :
582+ return None
583+ attr = self .heading [name ]
584+ if attr .adapter :
585+ value = attr .adapter .put (value )
586+ if value is None or (attr .numeric and (value == '' or np .isnan (np .float (value )))):
587+ # set default value
588+ placeholder , value = 'DEFAULT' , None
589+ else : # not NULL
590+ placeholder = '%s'
591+ if attr .uuid :
592+ if not isinstance (value , uuid .UUID ):
593+ try :
594+ value = uuid .UUID (value )
595+ except (AttributeError , ValueError ):
596+ raise DataJointError (
597+ 'badly formed UUID value {v} for attribute `{n}`' .format (v = value ,
598+ n = name ))
599+ value = value .bytes
600+ elif attr .is_blob :
601+ value = blob .pack (value )
602+ value = self .external [attr .store ].put (value ).bytes if attr .is_external else value
603+ elif attr .is_attachment :
604+ attachment_path = Path (value )
605+ if attr .is_external :
606+ # value is hash of contents
607+ value = self .external [attr .store ].upload_attachment (attachment_path ).bytes
608+ else :
609+ # value is filename + contents
610+ value = str .encode (attachment_path .name ) + b'\0 ' + attachment_path .read_bytes ()
611+ elif attr .is_filepath :
612+ value = self .external [attr .store ].upload_filepath (value ).bytes
613+ elif attr .numeric :
614+ value = str (int (value ) if isinstance (value , bool ) else value )
615+ return name , placeholder , value
616+
617+ def __make_row_to_insert (self , row , field_list , ignore_extra_fields ):
618+ """
619+ Helper function for insert and update
620+ :param row: A tuple to insert
621+ :return: a dict with fields 'names', 'placeholders', 'values'
622+ """
623+
624+ def check_fields (fields ):
625+ """
626+ Validates that all items in `fields` are valid attributes in the heading
627+ :param fields: field names of a tuple
628+ """
629+ if not field_list :
630+ if not ignore_extra_fields :
631+ for field in fields :
632+ if field not in self .heading :
633+ raise KeyError (u'`{0:s}` is not in the table heading' .format (field ))
634+ elif set (field_list ) != set (fields ).intersection (self .heading .names ):
635+ raise DataJointError ('Attempt to insert rows with different fields' )
636+
637+ if isinstance (row , np .void ): # np.array
638+ check_fields (row .dtype .fields )
639+ attributes = [self .__make_placeholder (name , row [name ], ignore_extra_fields )
640+ for name in self .heading if name in row .dtype .fields ]
641+ elif isinstance (row , collections .abc .Mapping ): # dict-based
642+ check_fields (row )
643+ attributes = [self .__make_placeholder (name , row [name ], ignore_extra_fields )
644+ for name in self .heading if name in row ]
645+ else : # positional
646+ try :
647+ if len (row ) != len (self .heading ):
648+ raise DataJointError (
649+ 'Invalid insert argument. Incorrect number of attributes: '
650+ '{given} given; {expected} expected' .format (
651+ given = len (row ), expected = len (self .heading )))
652+ except TypeError :
653+ raise DataJointError ('Datatype %s cannot be inserted' % type (row ))
654+ else :
655+ attributes = [self .__make_placeholder (name , value , ignore_extra_fields )
656+ for name , value in zip (self .heading , row )]
657+ if ignore_extra_fields :
658+ attributes = [a for a in attributes if a is not None ]
659+
660+ assert len (attributes ), 'Empty tuple'
661+ row_to_insert = dict (zip (('names' , 'placeholders' , 'values' ), zip (* attributes )))
662+ if not field_list :
663+ # first row sets the composition of the field list
664+ field_list .extend (row_to_insert ['names' ])
665+ else :
666+ # reorder attributes in row_to_insert to match field_list
667+ order = list (row_to_insert ['names' ].index (field ) for field in field_list )
668+ row_to_insert ['names' ] = list (row_to_insert ['names' ][i ] for i in order )
669+ row_to_insert ['placeholders' ] = list (row_to_insert ['placeholders' ][i ] for i in order )
670+ row_to_insert ['values' ] = list (row_to_insert ['values' ][i ] for i in order )
671+ return row_to_insert
672+
641673
642674def lookup_class_name (name , context , depth = 3 ):
643675 """
0 commit comments