Skip to content

Commit a70c81e

Browse files
committed
Finished adding local_copy logic and passed all unit tests
1 parent 613d8cb commit a70c81e

File tree

2 files changed

+98
-32
lines changed

2 files changed

+98
-32
lines changed

nipype/interfaces/io.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
209209
bucket = traits.Generic(mandatory=False,
210210
desc='Boto3 S3 bucket for manual override of bucket')
211211
# Set this if user wishes to have local copy of files as well
212-
local_dir = traits.Str(desc='Copy files locally as well as to S3 bucket')
212+
local_copy = traits.Str(desc='Copy files locally as well as to S3 bucket')
213213

214214
# Set call-able inputs attributes
215215
def __setattr__(self, key, value):
@@ -392,6 +392,10 @@ def _check_s3_base_dir(self):
392392
s3_str = 's3://'
393393
base_directory = self.inputs.base_directory
394394

395+
if not isdefined(base_directory):
396+
s3_flag = False
397+
return s3_flag
398+
395399
# Explicitly lower-case the "s3"
396400
if base_directory.lower().startswith(s3_str):
397401
base_dir_sp = base_directory.split('/')
@@ -616,7 +620,7 @@ def _upload_to_s3(self, src, dst):
616620
else:
617621
iflogger.info('Overwriting previous S3 file...')
618622

619-
except ClientError as exc:
623+
except ClientError:
620624
iflogger.info('New file to S3')
621625

622626
# Copy file up to S3 (either encrypted or not)
@@ -653,18 +657,21 @@ def _list_outputs(self):
653657
# Check if base directory reflects S3 bucket upload
654658
try:
655659
s3_flag = self._check_s3_base_dir()
656-
s3dir = self.inputs.base_directory
657-
if isdefined(self.inputs.container):
658-
s3dir = os.path.join(s3dir, self.inputs.container)
660+
if s3_flag:
661+
s3dir = self.inputs.base_directory
662+
if isdefined(self.inputs.container):
663+
s3dir = os.path.join(s3dir, self.inputs.container)
664+
else:
665+
s3dir = '<N/A>'
659666
# If encountering an exception during bucket access, set output
660667
# base directory to a local folder
661668
except Exception as exc:
669+
s3dir = '<N/A>'
670+
s3_flag = False
662671
if not isdefined(self.inputs.local_copy):
663672
local_out_exception = os.path.join(os.path.expanduser('~'),
664673
's3_datasink_' + self.bucket.name)
665674
outdir = local_out_exception
666-
else:
667-
outdir = self.inputs.local_copy
668675
# Log local copying directory
669676
iflogger.info('Access to S3 failed! Storing outputs locally at: '\
670677
'%s\nError: %s' %(outdir, exc))
@@ -673,8 +680,8 @@ def _list_outputs(self):
673680
if isdefined(self.inputs.container):
674681
outdir = os.path.join(outdir, self.inputs.container)
675682

676-
# If doing a localy output
677-
if not outdir.lower().startswith('s3://'):
683+
# If sinking to local folder
684+
if outdir != s3dir:
678685
outdir = os.path.abspath(outdir)
679686
# Create the directory if it doesn't exist
680687
if not os.path.exists(outdir):
@@ -714,18 +721,19 @@ def _list_outputs(self):
714721
if not os.path.isfile(src):
715722
src = os.path.join(src, '')
716723
dst = self._get_dst(src)
724+
if s3_flag:
725+
s3dst = os.path.join(s3tempoutdir, dst)
726+
s3dst = self._substitute(s3dst)
717727
dst = os.path.join(tempoutdir, dst)
718-
s3dst = os.path.join(s3tempoutdir, dst)
719728
dst = self._substitute(dst)
720729
path, _ = os.path.split(dst)
721730

722731
# If we're uploading to S3
723732
if s3_flag:
724-
dst = dst.replace(outdir, self.inputs.base_directory)
725-
self._upload_to_s3(src, dst)
726-
out_files.append(dst)
733+
self._upload_to_s3(src, s3dst)
734+
out_files.append(s3dst)
727735
# Otherwise, copy locally src -> dst
728-
else:
736+
if not s3_flag or isdefined(self.inputs.local_copy):
729737
# Create output directory if it doesnt exist
730738
if not os.path.exists(path):
731739
try:
@@ -787,6 +795,8 @@ class S3DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
787795
_outputs = traits.Dict(traits.Str, value={}, usedefault=True)
788796
remove_dest_dir = traits.Bool(False, usedefault=True,
789797
desc='remove dest directory when copying dirs')
798+
# Set this if user wishes to have local copy of files as well
799+
local_copy = traits.Str(desc='Copy files locally as well as to S3 bucket')
790800

791801
def __setattr__(self, key, value):
792802
if key not in self.copyable_trait_names():

nipype/interfaces/tests/test_io.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _check_for_fakes3():
177177

178178
# Check for fakes3
179179
try:
180-
ret_code = subprocess.check_call(['which', 'fakes3'])
180+
ret_code = subprocess.check_call(['which', 'fakes3'], stdout=open(os.devnull, 'wb'))
181181
if ret_code == 0:
182182
fakes3_found = True
183183
except subprocess.CalledProcessError as exc:
@@ -188,7 +188,29 @@ def _check_for_fakes3():
188188
# Return if found
189189
return fakes3_found
190190

191-
@skipif(noboto3)
191+
def _make_dummy_input():
192+
'''
193+
'''
194+
195+
# Import packages
196+
import tempfile
197+
198+
# Init variables
199+
input_dir = tempfile.mkdtemp()
200+
input_path = os.path.join(input_dir, 'datasink_test_s3.txt')
201+
202+
# Create input file
203+
with open(input_path, 'wb') as f:
204+
f.write('ABCD1234')
205+
206+
# Return path
207+
return input_path
208+
209+
# Check for fakes3
210+
fakes3 = _check_for_fakes3()
211+
212+
213+
@skipif(noboto3 or not fakes3)
192214
# Test datasink writes to s3 properly
193215
def test_datasink_to_s3():
194216
'''
@@ -208,13 +230,7 @@ def test_datasink_to_s3():
208230
output_dir = 's3://' + bucket_name
209231
# Local temporary filepaths for testing
210232
fakes3_dir = tempfile.mkdtemp()
211-
input_dir = tempfile.mkdtemp()
212-
input_path = os.path.join(input_dir, 'datasink_test_s3.txt')
213-
214-
# Check for fakes3
215-
fakes3_found = _check_for_fakes3()
216-
if not fakes3_found:
217-
return
233+
input_path = _make_dummy_input()
218234

219235
# Start up fake-S3 server
220236
proc = Popen(['fakes3', '-r', fakes3_dir, '-p', '4567'], stdout=open(os.devnull, 'wb'))
@@ -230,10 +246,6 @@ def test_datasink_to_s3():
230246
# Create bucket
231247
bucket = resource.create_bucket(Bucket=bucket_name)
232248

233-
# Create input file
234-
with open(input_path, 'wb') as f:
235-
f.write('ABCD1234')
236-
237249
# Prep datasink
238250
ds.inputs.base_directory = output_dir
239251
ds.inputs.container = container
@@ -249,15 +261,59 @@ def test_datasink_to_s3():
249261
dst_md5 = obj.e_tag.replace('"', '')
250262
src_md5 = hashlib.md5(open(input_path, 'rb').read()).hexdigest()
251263

252-
# Make sure md5sums match
253-
yield assert_equal, src_md5, dst_md5
254-
255264
# Kill fakes3
256265
proc.kill()
257266

258267
# Delete fakes3 folder and input file
259268
shutil.rmtree(fakes3_dir)
260-
shutil.rmtree(input_dir)
269+
shutil.rmtree(os.path.dirname(input_path))
270+
271+
# Make sure md5sums match
272+
yield assert_equal, src_md5, dst_md5
273+
274+
# Test the local copy attribute
275+
def test_datasink_localcopy():
276+
'''
277+
Function to validate DataSink will make local copy via local_copy
278+
attribute
279+
'''
280+
281+
# Import packages
282+
import hashlib
283+
import tempfile
284+
285+
# Init variables
286+
local_dir = tempfile.mkdtemp()
287+
container = 'outputs'
288+
attr_folder = 'text_file'
289+
290+
# Make dummy input file and datasink
291+
input_path = _make_dummy_input()
292+
ds = nio.DataSink()
293+
294+
# Set up datasink
295+
ds.inputs.container = container
296+
ds.inputs.local_copy = local_dir
297+
setattr(ds.inputs, attr_folder, input_path)
298+
299+
# Expected local copy path
300+
local_copy = os.path.join(local_dir, container, attr_folder,
301+
os.path.basename(input_path))
302+
303+
# Run the datasink
304+
ds.run()
305+
306+
# Check md5sums of both
307+
src_md5 = hashlib.md5(open(input_path, 'rb').read()).hexdigest()
308+
dst_md5 = hashlib.md5(open(local_copy, 'rb').read()).hexdigest()
309+
310+
# Delete temp diretories
311+
shutil.rmtree(os.path.dirname(input_path))
312+
shutil.rmtree(local_dir)
313+
314+
# Perform test
315+
yield assert_equal, src_md5, dst_md5
316+
261317

262318
@skipif(noboto)
263319
def test_s3datasink():
@@ -300,7 +356,7 @@ def test_datasink_substitutions():
300356
shutil.rmtree(indir)
301357
shutil.rmtree(outdir)
302358

303-
@skipif(noboto)
359+
@skipif(noboto or not fakes3)
304360
def test_s3datasink_substitutions():
305361
indir = mkdtemp(prefix='-Tmp-nipype_ds_subs_in')
306362
outdir = mkdtemp(prefix='-Tmp-nipype_ds_subs_out')

0 commit comments

Comments
 (0)