Skip to content

Commit 4659969

Browse files
committed
Fixes to the sample
During recent testing of the tutorial associated with this repo, I noticed two issues:
1. The setup commands create a bucket named `reproducible-ml`, but the code then uses a bucket named `pods-test`.
2. The Lambdas ran into permission errors when trying to write to and read from the S3 bucket.
These changes address both issues.
1 parent 8675302 commit 4659969

File tree

2 files changed

+12
-19
lines changed

2 files changed

+12
-19
lines changed

reproducible-ml/infer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@
44
import numpy
55
from joblib import load
66

7-
87
def handler(event, context):
98
# download the model and the test set from S3
109
s3_client = boto3.client("s3")
11-
s3_client.download_file(Bucket="pods-test", Key="test-set.npy", Filename="test-set.npy")
12-
s3_client.download_file(Bucket="pods-test", Key="model.joblib", Filename="model.joblib")
10+
s3_client.download_file(Bucket="reproducible-ml", Key="test-set.npy", Filename="/tmp/test-set.npy")
11+
s3_client.download_file(Bucket="reproducible-ml", Key="model.joblib", Filename="/tmp/model.joblib")
1312

14-
with open("test-set.npy", "rb") as f:
13+
with open("/tmp/test-set.npy", "rb") as f:
1514
X_test = numpy.load(f)
1615

17-
clf = load("model.joblib")
16+
clf = load("/tmp/model.joblib")
1817

1918
predicted = clf.predict(X_test)
2019
print("--> prediction result:", predicted)

reproducible-ml/train.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,20 @@ def handler(event, context):
3434
s3_client = boto3.client("s3")
3535
buffer = io.BytesIO()
3636
dump(clf, buffer)
37-
s3_client.put_object(Body=buffer.getvalue(), Bucket="pods-test", Key="model.joblib")
38-
37+
s3_client.put_object(Body=buffer.getvalue(), Bucket="reproducible-ml", Key="model.joblib")
38+
3939
# Save the test-set to the S3 bucket
40-
numpy.save('test-set.npy', X_test)
41-
with open('test-set.npy', 'rb') as f:
42-
s3_client.put_object(Body=f, Bucket="pods-test", Key="test-set.npy")
40+
numpy.save('/tmp/test-set.npy', X_test)
41+
with open('/tmp/test-set.npy', 'rb') as f:
42+
s3_client.put_object(Body=f, Bucket="reproducible-ml", Key="test-set.npy")
4343

4444

4545
def load_digits(*, n_class=10, return_X_y=False, as_frame=False):
4646
# download files from S3
4747
s3_client = boto3.client("s3")
48-
s3_client.download_file(Bucket="pods-test", Key="digits.csv.gz", Filename="digits.csv.gz")
49-
s3_client.download_file(Bucket="pods-test", Key="digits.rst", Filename="digits.rst")
50-
51-
# code below based on sklearn/datasets/_base.py
48+
s3_client.download_file(Bucket="reproducible-ml", Key="digits.csv.gz", Filename="/tmp/digits.csv.gz")
5249

53-
data = numpy.loadtxt('digits.csv.gz', delimiter=',')
54-
with open('digits.rst') as f:
55-
descr = f.read()
50+
data = numpy.loadtxt('/tmp/digits.csv.gz', delimiter=',')
5651
target = data[:, -1].astype(numpy.int, copy=False)
5752
flat_data = data[:, :-1]
5853
images = flat_data.view()
@@ -81,5 +76,4 @@ def load_digits(*, n_class=10, return_X_y=False, as_frame=False):
8176
frame=frame,
8277
feature_names=feature_names,
8378
target_names=numpy.arange(10),
84-
images=images,
85-
DESCR=descr)
79+
images=images)

0 commit comments

Comments
 (0)