Skip to content

Commit 49d326e

Browse files
committed
New AddCSVRow using pandas
1 parent 5620bbe commit 49d326e

File tree

1 file changed

+47
-120
lines changed

1 file changed

+47
-120
lines changed

nipype/algorithms/misc.py

Lines changed: 47 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030

3131
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
3232
InputMultiPath, OutputMultiPath,
33-
BaseInterfaceInputSpec, isdefined)
33+
BaseInterfaceInputSpec, isdefined,
34+
DynamicTraitedSpec )
3435
from ..utils.filemanip import fname_presuffix, split_filename
3536
iflogger = logging.getLogger('interface')
3637

@@ -1147,19 +1148,23 @@ def _list_outputs(self):
11471148
return outputs
11481149

11491150

1150-
class AddCSVRowInputSpec(TraitedSpec):
1151+
class AddCSVRowInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
11511152
in_file = traits.File(mandatory=True, desc='Input comma-separated value (CSV) files')
1152-
cols = traits.Int(desc='Number of columns')
1153-
field_headings = traits.List(traits.Str(), mandatory=True,
1154-
desc='Heading list of available field to be added.')
1155-
new_fields = traits.List( traits.Any(), mandatory=True, desc='List of new values in row', separator=',')
1156-
col_width = traits.Int( 9, mandatory=True, usedefault=True, desc='column width' )
1157-
float_dec = traits.Int( 6, mandatory=True, usedefault=True, desc='decimals' )
1153+
_outputs = traits.Dict( traits.Any, value={}, usedefault=True )
1154+
1155+
def __setattr__(self, key, value):
1156+
if key not in self.copyable_trait_names():
1157+
if not isdefined(value):
1158+
super(AddCSVRowInputSpec, self).__setattr__(key, value)
1159+
self._outputs[key] = value
1160+
else:
1161+
if key in self._outputs:
1162+
self._outputs[key] = value
1163+
super(AddCSVRowInputSpec, self).__setattr__(key, value)
11581164

11591165
class AddCSVRowOutputSpec(TraitedSpec):
11601166
csv_file = File(desc='Output CSV file containing rows ')
11611167

1162-
11631168
class AddCSVRow(BaseInterface):
11641169
"""
11651170
Short interface to add an extra row to a text file
@@ -1169,133 +1174,55 @@ class AddCSVRow(BaseInterface):
11691174
11701175
>>> import nipype.algorithms.misc as misc
11711176
>>> addrow = misc.AddCSVRow()
1172-
>>> addrow.inputs.in_file = 'degree.csv'
1173-
>>> addrow.inputs.field_headings = [ 'id', 'group', 'age', 'degree' ]
1174-
>>> addrow.inputs.new_fields = [ 'S400', 'male', '25', '10.5' ]
1177+
>>> addrow.inputs.in_file = 'scores.csv'
1178+
>>> addrow.inputs.si = 0.74
1179+
>>> addrow.inputs.di = 0.93
1180+
>>> addrow.subject_id = 'S400'
1181+
>>> addrow.inputs.list_of_values = [ 0.4, 0.7, 0.3 ]
11751182
>>> addrow.run() # doctest: +SKIP
11761183
"""
11771184
input_spec = AddCSVRowInputSpec
11781185
output_spec = AddCSVRowOutputSpec
1179-
_hdrstr = None
11801186

1181-
def _run_interface(self, runtime):
1182-
cols = 0
1183-
headings = []
1184-
col_width = self.inputs.col_width
1185-
float_dec = self.inputs.float_dec
1187+
def __init__(self, infields=None, force_run=True, **kwargs):
1188+
super(AddCSVRow, self).__init__(**kwargs)
1189+
undefined_traits = {}
1190+
self._infields = infields
11861191

1187-
if not isdefined( self.inputs.cols ) and not isdefined( self.inputs.field_headings ):
1188-
iflogger.error( 'Either number of cols or field headings is required' )
1192+
if infields:
1193+
for key in infields:
1194+
self.inputs.add_trait( key, traits.Any )
1195+
self.inputs._outputs[key] = Undefined
1196+
undefined_traits[key] = Undefined
1197+
self.inputs.trait_set( trait_change_notify=False, **undefined_traits )
11891198

1190-
if isdefined( self.inputs.cols ) and isdefined( self.inputs.field_headings ):
1191-
if( len( self.inputs.field_headings ) != self.inputs.cols ):
1192-
iflogger.error( 'Number of cols and length of field headings list should match' )
1193-
else:
1194-
cols = self.inputs.cols
1195-
headings = self.inputs.field_headings
1199+
if force_run:
1200+
self._always_run = True
11961201

1197-
if isdefined( self.inputs.cols ) and not isdefined( self.inputs.field_headings ):
1198-
cols = self.inputs.cols
1199-
iflogger.warn( 'No column headers were set.')
1200-
1201-
if not isdefined( self.inputs.cols ) and isdefined( self.inputs.field_headings ):
1202-
cols = len( self.inputs.field_headings )
1203-
headings = self.inputs.field_headings
1204-
1205-
if cols == 0:
1206-
iflogger.error( 'Number of cols and length of field headings must be > 0' )
1207-
1208-
if len( self.inputs.new_fields ) != cols:
1209-
iflogger.warn( 'Wrong length of fields (%d), does not match number of \
1210-
cols (%d)' % (len(self.inputs.new_fields), cols ) )
1211-
cols = len( self.inputs.new_fields )
1202+
def _run_interface(self, runtime):
1203+
import pandas as pd
12121204

1213-
if len(headings)>0:
1214-
argstr = '{:>%d}' % col_width
1215-
hdr = [ argstr.format( '"' + h + '"') for h in self.inputs.field_headings ]
1216-
self._hdrstr = ",".join(hdr) + '\n'
1205+
input_dict = {}
12171206

1207+
for key, val in self.inputs._outputs.items():
1208+
# expand lists to several columns
1209+
if isinstance(val, list):
1210+
for i,v in enumerate(val):
1211+
input_dict['%s_%d' % (key,i)]=v
1212+
else:
1213+
input_dict[key] = val
12181214

1219-
if op.exists( self.inputs.in_file ):
1220-
with open(self.inputs.in_file, 'a+') as in_file:
1221-
lines = in_file.readlines()
1215+
df = pd.DataFrame([input_dict])
12221216

1223-
if len(lines)>0 and lines[0]=='\n':
1224-
lines.pop()
1217+
if op.exists(self.inputs.in_file):
1218+
formerdf = pd.read_csv(self.inputs.in_file, index_col=0)
1219+
df = pd.concat( [formerdf, df], ignore_index=True )
12251220

1226-
if (len(headings)>0) and (len(lines)==0):
1227-
lines.insert(0, self._hdrstr )
1228-
in_file.write( "".join(lines) )
1229-
else:
1230-
with open(self.inputs.in_file, 'w+') as in_file:
1231-
in_file.write( self._hdrstr )
1232-
1233-
1234-
row_data = []
1235-
metadata = dict(separator=lambda t: t is not None)
1236-
for name, spec in sorted(self.inputs.traits(**metadata).items()):
1237-
values = getattr(self.inputs, name)
1238-
for v in values:
1239-
argstr = '{:>%d}' % col_width
1240-
if type(v) is float:
1241-
argstr = '{:>%d.%df}' % ( col_width, float_dec )
1242-
if type(v) is str:
1243-
v = '"' + v + '"'
1244-
row_data.append( argstr.format(v) )
1245-
newrow = ",".join( row_data ) + '\n'
1246-
1247-
with open(self.inputs.in_file, 'r+') as in_file:
1248-
in_file.seek(-2, 2)
1249-
if in_file.read(2) == '\n\n':
1250-
in_file.seek(-1, 1)
1251-
in_file.write( newrow )
1221+
with open(self.inputs.in_file, 'w') as f:
1222+
df.to_csv(f)
12521223

12531224
return runtime
12541225

1255-
1256-
1257-
def _format_row(self, name, trait_spec, value):
1258-
"""A helper function for _run_interface
1259-
"""
1260-
argstr = trait_spec.argstr
1261-
iflogger.debug('%s_%s' % (name, str(value)))
1262-
if trait_spec.is_trait_type(traits.Bool) and "%" not in argstr:
1263-
if value:
1264-
# Boolean options have no format string. Just append options
1265-
# if True.
1266-
return argstr
1267-
else:
1268-
return None
1269-
# traits.Either turns into traits.TraitCompound and does not have any
1270-
# inner_traits
1271-
elif trait_spec.is_trait_type(traits.List) \
1272-
or (trait_spec.is_trait_type(traits.TraitCompound)
1273-
and isinstance(value, list)):
1274-
# This is a bit simple-minded at present, and should be
1275-
# construed as the default. If more sophisticated behavior
1276-
# is needed, it can be accomplished with metadata (e.g.
1277-
# format string for list member str'ification, specifying
1278-
# the separator, etc.)
1279-
1280-
# Depending on whether we stick with traitlets, and whether or
1281-
# not we beef up traitlets.List, we may want to put some
1282-
# type-checking code here as well
1283-
sep = trait_spec.sep
1284-
if sep is None:
1285-
sep = ' '
1286-
if argstr.endswith('...'):
1287-
1288-
# repeatable option
1289-
# --id %d... will expand to
1290-
# --id 1 --id 2 --id 3 etc.,.
1291-
argstr = argstr.replace('...', '')
1292-
return sep.join([argstr % elt for elt in value])
1293-
else:
1294-
return argstr % sep.join(str(elt) for elt in value)
1295-
else:
1296-
# Append options using format string.
1297-
return argstr % value
1298-
12991226
def _list_outputs(self):
13001227
outputs = self.output_spec().get()
13011228
outputs['csv_file'] = self.inputs.in_file

0 commit comments

Comments
 (0)