Skip to content

Commit 341dfb6

Browse files
committed
adding lab1 tests
1 parent b3a32d4 commit 341dfb6

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

mitdeeplearning/lab1.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ def save_song_to_abc(song, filename="tmp"):
2929

3030
def abc2wav(abc_file):
3131
path_to_tool = os.path.join(cwd, 'bin', 'abc2wav')
32-
print("path_to_tool", path_to_tool)
33-
3432
cmd = "{} {}".format(path_to_tool, abc_file)
3533
return os.system(cmd)
3634

@@ -55,3 +53,26 @@ def play_generated_song(generated_text):
5553
play_song(song)
5654
print("None of the songs were valid, try training longer to improve \
5755
syntax.")
56+
57+
def test_batch_func_types(func, args):
58+
ret = func(*args)
59+
assert len(ret) == 2, "[FAIL] get_batch must return two arguments (input and label)"
60+
assert type(ret[0]) == np.ndarray, "[FAIL] test_batch_func_types: x is not np.array"
61+
assert type(ret[1]) == np.ndarray, "[FAIL] test_batch_func_types: y is not np.array"
62+
print("[PASS] test_batch_func_types")
63+
return True
64+
65+
def test_batch_func_shapes(func, args):
66+
dataset, seq_length, batch_size = args
67+
x, y = func(*args)
68+
correct = (batch_size, seq_length)
69+
assert x.shape == correct, "[FAIL] test_batch_func_shapes: x {} is not correct shape {}".format(x.shape, correct)
70+
assert y.shape == correct, "[FAIL] test_batch_func_shapes: y {} is not correct shape {}".format(y.shape, correct)
71+
print("[PASS] test_batch_func_shapes")
72+
return True
73+
74+
def test_batch_func_next_step(func, args):
75+
x, y = func(*args)
76+
assert (x[:,1:] == y[:,:-1]).all(), "[FAIL] test_batch_func_next_step: x_{t} must equal y_{t-1} for all t"
77+
print("[PASS] test_batch_func_next_step")
78+
return True

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ def get_dist(pkgname):
2222
setup(
2323
name = 'mitdeeplearning', # How you named your package folder (MyLib)
2424
packages = ['mitdeeplearning'], # Chose the same as "name"
25-
version = '0.3.5', # Start with a small number and increase it with every change you make
25+
version = '0.3.6', # Start with a small number and increase it with every change you make
2626
license='MIT', # Chose a license from here: https://help.github.com/articles/licensing-a-repository
2727
description = 'Official software labs for MIT Introduction to Deep Learning (http://introtodeeplearning.com)', # Give a short description about your library
2828
author = 'Alexander Amini', # Type in your name
2929
author_email = '[email protected]', # Type in your E-Mail
3030
url = 'http://introtodeeplearning.com', # Provide either the link to your github or to your website
31-
download_url = 'https://github.com/aamini/introtodeeplearning_labs/archive/v0.3.5.tar.gz', # I explain this later on
31+
download_url = 'https://github.com/aamini/introtodeeplearning_labs/archive/v0.3.6.tar.gz', # I explain this later on
3232
keywords = ['deep learning', 'neural networks', 'tensorflow', 'introduction'], # Keywords that define your package best
3333
install_requires=install_deps,
3434
classifiers=[

0 commit comments

Comments
 (0)