Skip to content

Commit 11e9efe

Browse files
author
aditya tewari
committed
args management in input call
1 parent 8cf5e1c commit 11e9efe

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

examples/lienet/task.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
#!/usr/bin/env python3
22
import argparse
33
import os
4+
import sys
45
import tensorflow as tf
56
import numpy as np
67

7-
from . import model
8-
from shared import utils
8+
import model
9+
from examples.shared import utils
910

1011
DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/LieData/G3D_Lie_data.zip"
1112
DATA_FOLDER = "lie20_half_inter1"
1213
G3D_CLASSES = 20
1314
VAL_SPLIT = 0.2
1415

1516

16-
def get_args():
17-
parser = argparse.ArgumentParser()
17+
def get_args(args):
18+
parser = argparse.ArgumentParser(args)
1819
parser.add_argument(
1920
'--job-dir', type=str, required=True, help='checkpoint dir'
2021
)
@@ -46,7 +47,7 @@ def get_args():
4647
return parser.parse_args()
4748

4849

49-
def prepare_data():
50+
def prepare_data(args):
5051
features, labels = utils.load_matlab_data("fea", args.data_dir, DATA_FOLDER)
5152
features = np.array([np.stack(example) for example in features.squeeze()])
5253
# reshape to [batch_size, spatial_dim, temp_dim, num_rows, num_cols]
@@ -61,7 +62,7 @@ def prepare_data():
6162

6263
def train_and_evaluate(args):
6364
utils.download_data(args.data_dir, DATA_URL, unpack=True)
64-
train, val = prepare_data()
65+
train, val = prepare_data(args)
6566

6667
train_dataset = (
6768
tf.data.Dataset.from_tensor_slices(train)
@@ -90,4 +91,5 @@ def train_and_evaluate(args):
9091

9192
if __name__ == "__main__":
9293
tf.get_logger().setLevel("INFO")
93-
train_and_evaluate(get_args())
94+
argv = sys.argv[1:]
95+
train_and_evaluate(get_args(argv))

0 commit comments

Comments
 (0)