|
| 1 | +import sys, json, os |
| 2 | +import random |
| 3 | + |
| 4 | +# ===== Definitions ========================================================= |
| 5 | +def expand(Vs, fr, to, soFar): |
| 6 | + soFarNew = [] |
| 7 | + for s in soFar: |
| 8 | + print Vs[fr] |
| 9 | + if (Vs[fr] == None): |
| 10 | + print ("ERROR: The order of json inputs and values must be preserved") |
| 11 | + sys.exit(1) |
| 12 | + for v in Vs[fr]: |
| 13 | + if s == '': |
| 14 | + soFarNew += [str(v)] |
| 15 | + else: |
| 16 | + soFarNew += [s+','+str(v)] |
| 17 | + if fr==to: |
| 18 | + return(soFarNew) |
| 19 | + else: |
| 20 | + return expand(Vs, fr+1, to, soFarNew) |
| 21 | + |
| 22 | +def generate_random(values, n_samples, benchmarkName): |
| 23 | + # select '#samples' random numbers between the range provided in settings.json file |
| 24 | + result = "" |
| 25 | + for s in range(samples[0]): |
| 26 | + if(benchmarkName=="p1b1"): |
| 27 | + # values = {1:epochs, 2: batch_size, 3: N1, 4: NE} |
| 28 | + t_epoch= random.randint(values[1][0], values[1][1]) |
| 29 | + t_batch_size= random.randint(values[2][0], values[2][1]) |
| 30 | + t_N1= random.randint(values[3][0], values[3][1]) |
| 31 | + t_NE= random.randint(values[4][0], values[4][1]) |
| 32 | + result+=str(t_epoch) + ',' + str(t_batch_size) + ',' + str(t_N1) + ',' + str(t_NE) |
| 33 | + elif(benchmarkName=="p1b3"): |
| 34 | + # values = {1:epochs, 2: batch_size, 3: test_cell_split, 4: drop} |
| 35 | + t_epoch= random.randint(values[1][0], values[1][1]) |
| 36 | + t_batch_size= random.randint(values[2][0], values[2][1]) |
| 37 | + t_tcs= random.uniform(values[3][0], values[3][1]) |
| 38 | + t_drop= random.uniform(values[4][0], values[4][1]) |
| 39 | + result+=str(t_epoch) + ',' + str(t_batch_size) + ',' + str(t_tcs) + ',' + str(t_drop) |
| 40 | + elif(benchmarkName=="nt3"): |
| 41 | + # values = {1:epochs, 2: batch_size, 3: classes} |
| 42 | + t_epoch= random.randint(values[1][0], values[1][1]) |
| 43 | + t_batch_size= random.randint(values[2][0], values[2][1]) |
| 44 | + t_classes= random.randint(values[3][0], values[3][1]) |
| 45 | + result+=str(t_epoch) + ',' + str(t_batch_size) + ',' + str(t_classes) |
| 46 | + elif(benchmarkName=="p2b1"): |
| 47 | + # values = {1:epochs, 2: batch_size, 3: molecular_epochs, 4: weight_decay} |
| 48 | + t_epoch= random.randint(values[1][0], values[1][1]) |
| 49 | + t_batch_size= random.randint(values[2][0], values[2][1]) |
| 50 | + t_me= random.randint(values[3][0], values[3][1]) |
| 51 | + t_wd= random.uniform(values[4][0], values[4][1]) |
| 52 | + result+=str(t_epoch) + ',' + str(t_batch_size) + ',' + str(t_me) + ',' + str(t_wd) |
| 53 | + elif(benchmarkName=="p3b1"): |
| 54 | + # values = {1:epochs, 2: batch_size, 3: shared_nnet_spec, 4: n_fold} |
| 55 | + t_epoch= random.randint(values[1][0], values[1][1]) |
| 56 | + t_batch_size= random.randint(values[2][0], values[2][1]) |
| 57 | + t_sns= random.randint(values[3][0], values[3][1]) |
| 58 | + t_nf= random.randint(values[4][0], values[4][1]) |
| 59 | + result+=str(t_epoch) + ',' + str(t_batch_size) + ',' + str(t_sns) + ',' + str(t_nf) |
| 60 | + else: |
| 61 | + print('ERROR: Tried all possible benchmarks, Invalid benchmark name or json file') |
| 62 | + sys.exit(1) |
| 63 | + # Populate the result string for writing sweep-parameters file |
| 64 | + if(s < (samples[0]-1)): |
| 65 | + result+=":" |
| 66 | + return result |
| 67 | + |
| 68 | +# ===== Main program ======================================================== |
| 69 | +if (len(sys.argv) < 3): |
| 70 | + print('requires arg1=settingsFilename and arg2=paramsFilename') |
| 71 | + sys.exit(1) |
| 72 | + |
| 73 | +settingsFilename = sys.argv[1] |
| 74 | +paramsFilename = sys.argv[2] |
| 75 | +benchmarkName = sys.argv[3] |
| 76 | +searchType = sys.argv[4] |
| 77 | + |
| 78 | +#Trying to open the settings file |
| 79 | +print("Reading settings: %s" % settingsFilename) |
| 80 | +try: |
| 81 | + with open(settingsFilename) as fp: |
| 82 | + settings = json.load(fp) |
| 83 | +except IOError as e: |
| 84 | + print("Could not open: %s" % settingsFilename) |
| 85 | + print("PWD is: '%s'" % os.getcwd()) |
| 86 | + sys.exit(1) |
| 87 | + |
| 88 | +# Read in the variables from json file |
| 89 | +# Register new variables for any benchmark here |
| 90 | +#Common variables |
| 91 | +epochs = settings.get('parameters').get('epochs') |
| 92 | +batch_size = settings.get('parameters').get('batch_size') |
| 93 | +# P1B1 |
| 94 | +N1 = settings.get('parameters').get('N1') |
| 95 | +NE = settings.get('parameters').get('NE') |
| 96 | +#NT3 |
| 97 | +classes = settings.get('parameters').get('classes') |
| 98 | +#P2B1 |
| 99 | +molecular_epochs = settings.get('parameters').get('molecular_epochs') |
| 100 | +weight_decay = settings.get('parameters').get('weight_decay') |
| 101 | +#P3B1 |
| 102 | +shared_nnet_spec = settings.get('parameters').get('shared_nnet_spec') |
| 103 | +n_fold = settings.get('parameters').get('n_fold') |
| 104 | +#P1B3 |
| 105 | +test_cell_split = settings.get('parameters').get('test_cell_split') |
| 106 | +drop = settings.get('parameters').get('drop') |
| 107 | + |
| 108 | +# For random scheme determine number of samples |
| 109 | +samples = settings.get('samples', {}).get('num', None) |
| 110 | + |
| 111 | + |
| 112 | +# Make values for computing grid sweep parameters |
| 113 | +values = {} |
| 114 | +if(benchmarkName=="p1b1"): |
| 115 | + values = {1:epochs, 2: batch_size, 3: N1, 4: NE} |
| 116 | + print values |
| 117 | +elif(benchmarkName=="p1b3"): |
| 118 | + values = {1:epochs, 2: batch_size, 3: test_cell_split, 4: drop} |
| 119 | + print values |
| 120 | +elif(benchmarkName=="nt3"): |
| 121 | + values = {1:epochs, 2: batch_size, 3: classes} |
| 122 | + print values |
| 123 | +elif(benchmarkName=="p2b1"): |
| 124 | + values = {1:epochs, 2: batch_size, 3: molecular_epochs, 4: weight_decay} |
| 125 | + print values |
| 126 | +elif(benchmarkName=="p3b1"): |
| 127 | + values = {1:epochs, 2: batch_size, 3: shared_nnet_spec, 4: n_fold} |
| 128 | + print values |
| 129 | +else: |
| 130 | + print('ERROR: Tried all possible benchmarks, Invalid benchmark name or json file') |
| 131 | + sys.exit(1) |
| 132 | + |
| 133 | +result = {} |
| 134 | +if(searchType == "grid"): |
| 135 | + results = expand(values, 1, len(values), ['']) |
| 136 | + result = ':'.join(results) |
| 137 | +elif(searchType =="random"): |
| 138 | + if(samples == None): |
| 139 | + print ("ERROR: Provide number of samples in json file") |
| 140 | + sys.exit(1) |
| 141 | + result = generate_random(values, samples, benchmarkName) |
| 142 | +else: |
| 143 | + print ("ERROR: Invalid search type, specify either - grid or random") |
| 144 | + sys.exit(1) |
| 145 | + |
| 146 | + |
| 147 | +with open(paramsFilename, 'w') as the_file: |
| 148 | + the_file.write(result) |
| 149 | + |
0 commit comments