
Commit 89d7e9c

Added fakes3 integration with datasink and started adding a local_copy flag to the output generation logic
1 parent d25afb5 commit 89d7e9c
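
For reference, a minimal sketch of how the two inputs introduced by this commit would be used from a workflow script. The bucket name, container, and paths below are illustrative, not part of the commit; `local_copy` is the flag this commit starts adding.

    # Hypothetical usage of the new DataSink inputs (names are illustrative)
    import boto3
    import nipype.interfaces.io as nio

    ds = nio.DataSink()
    ds.inputs.base_directory = 's3://my-bucket/outputs'
    ds.inputs.container = 'sub-01'

    # Optional: hand DataSink a pre-built boto3 Bucket object (e.g. one
    # pointed at a fakes3 endpoint); otherwise the bucket is fetched by
    # name from the 's3://...' base_directory.
    ds.inputs.bucket = boto3.resource('s3').Bucket('my-bucket')

    # Optional: also keep a local copy of everything sent to S3
    ds.inputs.local_copy = '/tmp/datasink_local'

    setattr(ds.inputs, 'text_file', '/path/to/result.txt')
    ds.run()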

File tree

2 files changed: +166, −36 lines


nipype/interfaces/io.py

Lines changed: 62 additions & 36 deletions
@@ -205,6 +205,11 @@ class DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
                                  'access')
     encrypt_bucket_keys = traits.Bool(desc='Flag indicating whether to use S3 '\
                                            'server-side AES-256 encryption')
+    # Set this if user wishes to override the bucket with their own
+    bucket = traits.Generic(mandatory=False,
+                            desc='Boto3 S3 bucket for manual override of bucket')
+    # Set this if user wishes to have a local copy of files as well
+    local_copy = traits.Str(desc='Copy files locally as well as to S3 bucket')
 
     # Set call-able inputs attributes
     def __setattr__(self, key, value):
@@ -385,7 +390,6 @@ def _check_s3_base_dir(self):
 
         # Init variables
         s3_str = 's3://'
-        sep = os.path.sep
         base_directory = self.inputs.base_directory
 
         # Explicitly lower-case the "s3"
@@ -396,11 +400,16 @@ def _check_s3_base_dir(self):
 
         # Check if 's3://' in base dir
         if base_directory.startswith(s3_str):
+            # Attempt to access bucket
             try:
                 # Expects bucket name to be 's3://bucket_name/base_dir/..'
-                bucket_name = base_directory.split(s3_str)[1].split(sep)[0]
+                bucket_name = base_directory.split(s3_str)[1].split('/')[0]
                 # Get the actual bucket object
-                self.bucket = self._fetch_bucket(bucket_name)
+                if self.inputs.bucket:
+                    self.bucket = self.inputs.bucket
+                else:
+                    self.bucket = self._fetch_bucket(bucket_name)
+            # Report error in case of exception
             except Exception as exc:
                 err_msg = 'Unable to access S3 bucket. Error:\n%s. Exiting...'\
                           % exc
@@ -566,7 +575,7 @@ def _upload_to_s3(self, src, dst):
         bucket = self.bucket
         iflogger = logging.getLogger('interface')
         s3_str = 's3://'
-        s3_prefix = os.path.join(s3_str, bucket.name)
+        s3_prefix = s3_str + bucket.name
 
         # Explicitly lower-case the "s3"
         if dst.lower().startswith(s3_str):
@@ -629,41 +638,53 @@ def _list_outputs(self):
         iflogger = logging.getLogger('interface')
         outputs = self.output_spec().get()
         out_files = []
-        outdir = self.inputs.base_directory
+        # Use hardlink
         use_hardlink = str2bool(config.get('execution', 'try_hard_link_datasink'))
 
-        # If base directory isn't given, assume current directory
-        if not isdefined(outdir):
-            outdir = '.'
+        # Set local output directory if specified
+        if isdefined(self.inputs.local_copy):
+            outdir = self.inputs.local_copy
+        else:
+            outdir = self.inputs.base_directory
+            # If base directory isn't given, assume current directory
+            if not isdefined(outdir):
+                outdir = '.'
 
-        # Check if base directory reflects S3-bucket upload
+        # Check if base directory reflects S3 bucket upload
         try:
             s3_flag = self._check_s3_base_dir()
+            s3dir = self.inputs.base_directory
+            if isdefined(self.inputs.container):
+                s3dir = os.path.join(s3dir, self.inputs.container)
         # If encountering an exception during bucket access, set output
         # base directory to a local folder
         except Exception as exc:
-            local_out_exception = os.path.join(os.path.expanduser('~'),
-                                               'data_output')
+            if not isdefined(self.inputs.local_copy):
+                local_out_exception = os.path.join(os.path.expanduser('~'),
+                                                   's3_datasink_' + self.bucket.name)
+                outdir = local_out_exception
+            else:
+                outdir = self.inputs.local_copy
+            # Log local copying directory
             iflogger.info('Access to S3 failed! Storing outputs locally at: '\
-                          '%s\nError: %s' %(local_out_exception, exc))
-            self.inputs.base_directory = local_out_exception
-
-        # If not accessing S3, just set outdir to local absolute path
-        if not s3_flag:
-            outdir = os.path.abspath(outdir)
+                          '%s\nError: %s' %(outdir, exc))
 
         # If container input is given, append that to outdir
         if isdefined(self.inputs.container):
             outdir = os.path.join(outdir, self.inputs.container)
-        # Create the directory if it doesn't exist
-        if not os.path.exists(outdir):
-            try:
-                os.makedirs(outdir)
-            except OSError, inst:
-                if 'File exists' in inst:
-                    pass
-                else:
-                    raise(inst)
+
+        # If doing a local output
+        if not outdir.lower().startswith('s3://'):
+            outdir = os.path.abspath(outdir)
+            # Create the directory if it doesn't exist
+            if not os.path.exists(outdir):
+                try:
+                    os.makedirs(outdir)
+                except OSError, inst:
+                    if 'File exists' in inst:
+                        pass
+                    else:
+                        raise(inst)
 
         # Iterate through outputs attributes {key : path(s)}
         for key, files in self.inputs._outputs.items():
@@ -672,10 +693,14 @@ def _list_outputs(self):
             iflogger.debug("key: %s files: %s" % (key, str(files)))
             files = filename_to_list(files)
             tempoutdir = outdir
+            if s3_flag:
+                s3tempoutdir = s3dir
             for d in key.split('.'):
                 if d[0] == '@':
                     continue
                 tempoutdir = os.path.join(tempoutdir, d)
+                if s3_flag:
+                    s3tempoutdir = os.path.join(s3tempoutdir, d)
 
             # flattening list
             if isinstance(files, list):
@@ -690,25 +715,26 @@ def _list_outputs(self):
                     src = os.path.join(src, '')
                 dst = self._get_dst(src)
                 dst = os.path.join(tempoutdir, dst)
+                s3dst = os.path.join(s3tempoutdir, dst)
                 dst = self._substitute(dst)
                 path, _ = os.path.split(dst)
 
-                # Create output directory if it doesnt exist
-                if not os.path.exists(path):
-                    try:
-                        os.makedirs(path)
-                    except OSError, inst:
-                        if 'File exists' in inst:
-                            pass
-                        else:
-                            raise(inst)
-
                 # If we're uploading to S3
                 if s3_flag:
+                    dst = dst.replace(outdir, self.inputs.base_directory)
                     self._upload_to_s3(src, dst)
                     out_files.append(dst)
                 # Otherwise, copy locally src -> dst
                 else:
+                    # Create output directory if it doesn't exist
+                    if not os.path.exists(path):
+                        try:
+                            os.makedirs(path)
+                        except OSError, inst:
+                            if 'File exists' in inst:
+                                pass
+                            else:
+                                raise(inst)
                     # If src is a file, copy it to dst
                     if os.path.isfile(src):
                         iflogger.debug('copyfile: %s %s' % (src, dst))
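
One detail worth noting in the _check_s3_base_dir change above: bucket names are now parsed by splitting on a literal '/' rather than os.path.sep, since S3 keys always use forward slashes while os.path.sep is '\' on Windows. A self-contained sketch of that parsing (the path is illustrative):

    # Sketch of the bucket-name parsing used above (path is illustrative)
    s3_str = 's3://'
    base_directory = 's3://my-bucket/derivatives/sub-01'

    # Splitting on '/' (not os.path.sep) keeps this correct on Windows,
    # where os.path.sep is '\\' and would leave the remainder unsplit.
    bucket_name = base_directory.split(s3_str)[1].split('/')[0]
    assert bucket_name == 'my-bucket'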

nipype/interfaces/tests/test_io.py

Lines changed: 104 additions & 0 deletions
@@ -13,13 +13,21 @@
 import nipype.interfaces.io as nio
 from nipype.interfaces.base import Undefined
 
+# Check for boto
 noboto = False
 try:
     import boto
     from boto.s3.connection import S3Connection, OrdinaryCallingFormat
 except:
     noboto = True
 
+# Check for boto3
+noboto3 = False
+try:
+    import boto3
+    from botocore.utils import fix_s3_host
+except:
+    noboto3 = True
 
 def test_datagrabber():
     dg = nio.DataGrabber()
@@ -155,6 +163,102 @@ def test_datasink():
     ds = nio.DataSink(infields=['test'])
     yield assert_true, 'test' in ds.inputs.copyable_trait_names()
 
+# Function to check for fakes3
+def _check_for_fakes3():
+    '''
+    Function used internally to check for fakes3 installation
+    '''
+
+    # Import packages
+    import subprocess
+
+    # Init variables
+    fakes3_found = False
+
+    # Check for fakes3
+    try:
+        ret_code = subprocess.check_call(['which', 'fakes3'])
+        if ret_code == 0:
+            fakes3_found = True
+    except subprocess.CalledProcessError as exc:
+        print 'fakes3 not found, install via \'gem install fakes3\', skipping test...'
+    except:
+        print 'Unable to check for fakes3 installation, skipping test...'
+
+    # Return if found
+    return fakes3_found
+
+@skipif(noboto3)
+# Test datasink writes to s3 properly
+def test_datasink_to_s3():
+    '''
+    This function tests to see if the S3 functionality of a DataSink
+    works properly
+    '''
+
+    # Import packages
+    import hashlib
+    import tempfile
+
+    # Init variables
+    ds = nio.DataSink()
+    bucket_name = 'test'
+    container = 'outputs'
+    attr_folder = 'text_file'
+    output_dir = 's3://' + bucket_name
+    # Local temporary filepaths for testing
+    fakes3_dir = tempfile.mkdtemp()
+    input_dir = tempfile.mkdtemp()
+    input_path = os.path.join(input_dir, 'datasink_test_s3.txt')
+
+    # Check for fakes3
+    fakes3_found = _check_for_fakes3()
+    if not fakes3_found:
+        return
+
+    # Start up fake-S3 server
+    proc = Popen(['fakes3', '-r', fakes3_dir, '-p', '4567'], stdout=open(os.devnull, 'wb'))
+
+    # Init boto3 s3 resource to talk with fakes3
+    resource = boto3.resource(aws_access_key_id='mykey',
+                              aws_secret_access_key='mysecret',
+                              service_name='s3',
+                              endpoint_url='http://localhost:4567',
+                              use_ssl=False)
+    resource.meta.client.meta.events.unregister('before-sign.s3', fix_s3_host)
+
+    # Create bucket
+    bucket = resource.create_bucket(Bucket=bucket_name)
+
+    # Create input file
+    with open(input_path, 'wb') as f:
+        f.write('ABCD1234')
+
+    # Prep datasink
+    ds.inputs.base_directory = output_dir
+    ds.inputs.container = container
+    ds.inputs.bucket = bucket
+    setattr(ds.inputs, attr_folder, input_path)
+
+    # Run datasink
+    ds.run()
+
+    # Get MD5sums and compare
+    key = '/'.join([container, attr_folder, os.path.basename(input_path)])
+    obj = bucket.Object(key=key)
+    dst_md5 = obj.e_tag.replace('"', '')
+    src_md5 = hashlib.md5(open(input_path, 'rb').read()).hexdigest()
+
+    # Make sure md5sums match
+    yield assert_equal, src_md5, dst_md5
+
+    # Kill fakes3
+    proc.kill()
+
+    # Delete fakes3 folder and input file
+    shutil.rmtree(fakes3_dir)
+    shutil.rmtree(input_dir)
+
 @skipif(noboto)
 def test_s3datasink():
     ds = nio.S3DataSink()
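
The checksum comparison in the new test leans on the fact that, for single-part uploads, an S3 object's ETag is the hex MD5 digest of its contents, a behavior fakes3 mimics. A standalone sketch of the same check, assuming the fakes3 server from the test is still running on port 4567 with the test's bucket and key layout:

    # Sketch of the ETag/MD5 check, assuming a fakes3 server on port 4567
    # and the same bucket/key layout the test creates
    import hashlib
    import boto3

    resource = boto3.resource(service_name='s3',
                              aws_access_key_id='mykey',
                              aws_secret_access_key='mysecret',
                              endpoint_url='http://localhost:4567',
                              use_ssl=False)
    obj = resource.Bucket('test').Object(key='outputs/text_file/datasink_test_s3.txt')

    # For single-part uploads the ETag is the object's MD5 digest,
    # returned wrapped in double quotes
    remote_md5 = obj.e_tag.replace('"', '')
    local_md5 = hashlib.md5('ABCD1234').hexdigest()
    assert local_md5 == remote_md5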

0 commit comments
