Skip to content

Commit 762dd3e

Browse files
authored
Update p3b3_baseline_keras2.py
1 parent 7c4a1f2 commit 762dd3e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

Pilot3/P3B3/p3b3_baseline_keras2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ def run(gParameters, fpath):
156156
test_x = np.load( fpath + '/test_X.npy' )
157157
test_y = np.load( fpath + '/test_Y.npy' )
158158

159+
160+
for task in range( len( train_y[ 0, : ] ) ):
161+
cat = np.unique( train_y[ :, task ] )
162+
train_y[ :, task ] = [ np.where( cat == x )[ 0 ][ 0 ] for x in train_y[ :, task ] ]
163+
test_y[ :, task ] = [ np.where( cat == x )[ 0 ][ 0 ] for x in test_y[ :, task ] ]
159164

160165
run_filter_sizes = []
161166
run_num_filters = []

0 commit comments

Comments
 (0)