
Commit e790951

added support for benchmark data sets
1 parent 1847d5f commit e790951

File tree

1 file changed: +84 −37 lines changed


examples/ADRP/adrp.py

Lines changed: 84 additions & 37 deletions
@@ -18,6 +18,7 @@
 
 additional_definitions = [
     {"name": "latent_dim", "action": "store", "type": int, "help": "latent dimensions"},
+    {"name": "benchmark_data", "action": "store", "type": candle.str2bool, "default": False, "help": "Use prepared benchmark data"},
     {
         "name": "residual",
         "type": candle.str2bool,
@@ -230,43 +231,89 @@ def get_model(params):
 
 
 def load_data(params, seed):
-    header_url = params["header_url"]
-    dh_dict, th_list = load_headers(
-        "descriptor_headers.csv", "training_headers.csv", header_url
-    )
-    offset = 6  # descriptor starts at index 6
-    desc_col_idx = [dh_dict[key] + offset for key in th_list]
-
-    url = params["data_url"]
-    file_train = (
-        "ml." + params["base_name"] + ".Orderable_zinc_db_enaHLL.sorted.4col.dd.parquet"
-    )
-    # file_train = params["train_data"]
-    train_file = candle.get_file(file_train, url + file_train, cache_subdir="Pilot1")
-    # df = (pd.read_csv(data_path,skiprows=1).values).astype('float32')
-    print("Loading data...")
-    df = pd.read_parquet(train_file)
-    print("done")
-
-    # df_y = df[:,0].astype('float32')
-    df_y = df["reg"].astype("float32")
-    # df_x = df[:, 1:PL].astype(np.float32)
-    df_x = df.iloc[:, desc_col_idx].astype(np.float32)
-
-    bins = np.arange(0, 20)
-    histogram, bin_edges = np.histogram(df_y, bins=bins, density=False)
-    print("Histogram of samples (bins, counts)")
-    print(bin_edges)
-    print(histogram)
-
-    # scaler = MaxAbsScaler()
-
-    scaler = StandardScaler()
-    df_x = scaler.fit_transform(df_x)
-
-    X_train, X_test, Y_train, Y_test = train_test_split(
-        df_x, df_y, test_size=0.20, random_state=42
-    )
+    if "benchmark_data" in params and params["benchmark_data"]:
+        if params["train_data"].endswith(".parquet"):
+            header_url = params["header_url"]
+            dh_dict, th_list = load_headers(
+                "descriptor_headers.csv", "training_headers.csv", header_url
+            )
+            offset = 6  # descriptor starts at index 6
+            desc_col_idx = [dh_dict[key] + offset for key in th_list]
+
+            url = params["data_url"]
+
+            # download the three pre-split benchmark files named in the params
+            train_file = candle.get_file(params["train_data"], url + params["train_data"], cache_subdir="Pilot1")
+            test_file = candle.get_file(params["test_data"], url + params["test_data"], cache_subdir="Pilot1")
+            val_file = candle.get_file(params["val_data"], url + params["val_data"], cache_subdir="Pilot1")
+
+            # df = (pd.read_csv(data_path,skiprows=1).values).astype('float32')
+            print("Loading data...")
+            train_df = pd.read_parquet(train_file)
+            val_df = pd.read_parquet(val_file)
+            test_df = pd.read_parquet(test_file)
+            print("done")
+
+            train_df_y = train_df["reg"].astype("float32")
+            train_df_x = train_df.iloc[:, desc_col_idx].astype(np.float32)
+            test_df_y = test_df["reg"].astype("float32")
+            test_df_x = test_df.iloc[:, desc_col_idx].astype(np.float32)
+            val_df_y = val_df["reg"].astype("float32")
+            val_df_x = val_df.iloc[:, desc_col_idx].astype(np.float32)
+
+            bins = np.arange(0, 20)
+            histogram, bin_edges = np.histogram(train_df_y, bins=bins, density=False)
+            print("Histogram of samples (bins, counts)")
+            print(bin_edges)
+            print(histogram)
+
+            # fit the scaler on the training split only, then reuse it for test/val
+            scaler = StandardScaler()
+            train_df_x = scaler.fit_transform(train_df_x)
+            test_df_x = scaler.transform(test_df_x)
+            val_df_x = scaler.transform(val_df_x)
+
+            return train_df_x, train_df_y, val_df_x, val_df_y, train_df_x.shape[1], histogram
+            # return X_train, Y_train, X_test, Y_test, X_train.shape[1], histogram
+
+    else:
+        header_url = params["header_url"]
+        dh_dict, th_list = load_headers(
+            "descriptor_headers.csv", "training_headers.csv", header_url
+        )
+        offset = 6  # descriptor starts at index 6
+        desc_col_idx = [dh_dict[key] + offset for key in th_list]
+
+        url = params["data_url"]
+        file_train = (
+            "ml." + params["base_name"] + ".Orderable_zinc_db_enaHLL.sorted.4col.dd.parquet"
+        )
+        # file_train = params["train_data"]
+        train_file = candle.get_file(file_train, url + file_train, cache_subdir="Pilot1")
+        # df = (pd.read_csv(data_path,skiprows=1).values).astype('float32')
+        print("Loading data...")
+        df = pd.read_parquet(train_file)
+        print("done")
+
+        # df_y = df[:,0].astype('float32')
+        df_y = df["reg"].astype("float32")
+        # df_x = df[:, 1:PL].astype(np.float32)
+        df_x = df.iloc[:, desc_col_idx].astype(np.float32)
+
+        bins = np.arange(0, 20)
+        histogram, bin_edges = np.histogram(df_y, bins=bins, density=False)
+        print("Histogram of samples (bins, counts)")
+        print(bin_edges)
+        print(histogram)
+
+        # scaler = MaxAbsScaler()
+
+        scaler = StandardScaler()
+        df_x = scaler.fit_transform(df_x)
+
+        X_train, X_test, Y_train, Y_test = train_test_split(
+            df_x, df_y, test_size=0.20, random_state=42
+        )
 
     print("x_train shape:", X_train.shape)
     print("x_test shape:", X_test.shape)

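With the new flag in place, load_data dispatches on params["benchmark_data"]. A hypothetical invocation of the updated function; the URLs and file names below are illustrative placeholders, since in the real benchmark these values come from the CANDLE config file and argument parser rather than hand-written code:

    # all values here are placeholders, not the benchmark's actual locations
    params = {
        "benchmark_data": True,
        "header_url": "https://example.org/headers/",
        "data_url": "https://example.org/data/",
        "train_data": "adrp_train.parquet",
        "test_data": "adrp_test.parquet",
        "val_data": "adrp_val.parquet",
    }
    x_train, y_train, x_val, y_val, n_features, hist = load_data(params, seed=2020)

With benchmark_data left at its default of False, the call falls through to the original path, which downloads the single "ml.<base_name>..." parquet file and carves a 20% test set out of it with train_test_split.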