1
1
#!/usr/bin/env python3
2
2
import argparse
3
3
import os
4
+ import sys
4
5
import tensorflow as tf
5
6
import numpy as np
6
7
7
- from . import model
8
- from shared import utils
8
+ import model
9
+ from examples . shared import utils
9
10
10
11
DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/LieData/G3D_Lie_data.zip"
11
12
DATA_FOLDER = "lie20_half_inter1"
12
13
G3D_CLASSES = 20
13
14
VAL_SPLIT = 0.2
14
15
15
16
16
- def get_args ():
17
- parser = argparse .ArgumentParser ()
17
+ def get_args (args ):
18
+ parser = argparse .ArgumentParser (args )
18
19
parser .add_argument (
19
20
'--job-dir' , type = str , required = True , help = 'checkpoint dir'
20
21
)
@@ -46,7 +47,7 @@ def get_args():
46
47
return parser .parse_args ()
47
48
48
49
49
- def prepare_data ():
50
+ def prepare_data (args ):
50
51
features , labels = utils .load_matlab_data ("fea" , args .data_dir , DATA_FOLDER )
51
52
features = np .array ([np .stack (example ) for example in features .squeeze ()])
52
53
# reshape to [batch_size, spatial_dim, temp_dim, num_rows, num_cols]
@@ -61,7 +62,7 @@ def prepare_data():
61
62
62
63
def train_and_evaluate (args ):
63
64
utils .download_data (args .data_dir , DATA_URL , unpack = True )
64
- train , val = prepare_data ()
65
+ train , val = prepare_data (args )
65
66
66
67
train_dataset = (
67
68
tf .data .Dataset .from_tensor_slices (train )
@@ -90,4 +91,5 @@ def train_and_evaluate(args):
90
91
91
92
if __name__ == "__main__" :
92
93
tf .get_logger ().setLevel ("INFO" )
93
- train_and_evaluate (get_args ())
94
+ argv = sys .argv [1 :]
95
+ train_and_evaluate (get_args (argv ))
0 commit comments