Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions src/generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

import yaml
import time
import os
import os, sys, shutil
import math
from sklearn.model_selection._search import ParameterSampler, ParameterGrid
import numpy as np
import subprocess
import random
from simulate import simulate
import scipy.stats.distributions as dists

import argparse
parser = argparse.ArgumentParser(description='Generate ball simulations')
Expand All @@ -25,11 +26,16 @@
help='Path to save simulations')
args = parser.parse_args()

num_timesteps_per_simulation = 100
num_timesteps_per_simulation = 60
image_size = 64

# Parameters which can change per ball
per_ball_parameter_space = {
#"x":dists.uniform(0.1,0.9),
#"y":dists.uniform(0.1,0.9),
#"dx":dists.uniform(-0.5,0.4),
#"dy":dists.uniform(-0.5,0.4),
#"radius":dists.uniform(0.1,0.3),
"x":[0.1*i for i in range(1,10)],
"y":[0.1*i for i in range(1,10)],
"dx":[0.1*i for i in range(-5,5)],
Expand All @@ -38,20 +44,37 @@
"foreground_color":['#{0:x}{0:x}{0:x}'.format(x) for x in range(15,-1,-1)],
}

#per_ball_parameter_space = {
# "x":dists.uniform(0.1,0.9),
# "y":dists.uniform(0.1,0.9),
# "dx":dists.uniform(-0.5,0.4),
# "dy":dists.uniform(-0.5,0.4),
# "radius":dists.uniform(0.1,0.3),
# "foreground_color":['#{0:x}{0:x}{0:x}'.format(x) for x in range(15,-1,-1)]
#}

# Global parameters
parameter_space = {
"num_balls":range(1,4),
# "gx":[0.1*i for i in range(0,5)],
# "gy":[0.1*i for i in range(0,5)],
"num_balls":range(1,5),
#"gx":dists.uniform(0,0.04), # As long as at least one parameter is a distribution, the sampler will sample WITHOUT REPLACEMENT (important)
"gy":dists.uniform(0,0.02),
"background_color":['#{0:x}{0:x}{0:x}'.format(x) for x in range(0,16)],
}

## How the rest of the parameter space make look if you wish to sample from continuous space
#parameter_space = {
# "num_balls":range(1,4), # YOU CANNOT MAKE NUMBALLS A `dists.randint' distribution without changing your ``Building space'' code
# "gx":dists.uniform(0,0.4),
# "gy":dists.uniform(0,0.4),
# "background_color":['#{0:x}{0:x}{0:x}'.format(x) for x in range(0,16)]
#}

# Make dataset dir
try:
os.mkdir(args.output_path)
pass
except:
raise Exception("Folder " + args.output_path + " already exists.")
raise Exception("Folder " + args.output_path + " already exists.")

# write config
config = {}
Expand Down Expand Up @@ -82,12 +105,19 @@
# Render
simulation_num = 0
balls = []
sampler = ParameterGrid(parameter_space)

while simulation_num < args.number_of_simulations:
random_idx = random.randint(0, len(sampler)-1)
params = sampler[random_idx]
# If one of the parameters given is a scipy.stats.distributions object, then it will sample with replacement
sampler = ParameterSampler(parameter_space, min(1000, args.number_of_simulations))
sampler = list(sampler)
sampler_idx = 0


while simulation_num < args.number_of_simulations:
if(sampler_idx == min(1000, args.number_of_simulations)):
sampler_idx = 0
sampler = list(ParameterSampler(parameter_space, min(1000, args.number_of_simulations)))
params = sampler[sampler_idx]
sampler_idx += 1
good_simulation = True
for i in range(params["num_balls"]):

Expand Down