|
182 | 182 | "from itertools import islice\n", |
183 | 183 | "\n", |
184 | 184 | "with make_env() as env:\n", |
185 | | - " # The two datasets we will be using:\n", |
186 | | - " npb = env.datasets[\"npb-v0\"]\n", |
187 | | - " chstone = env.datasets[\"chstone-v0\"]\n", |
| 185 | + " # The two datasets we will be using:\n", |
| 186 | + " npb = env.datasets[\"npb-v0\"]\n", |
| 187 | + " chstone = env.datasets[\"chstone-v0\"]\n", |
188 | 188 | "\n", |
189 | | - " # Each dataset has a `benchmarks()` method that returns an iterator over the\n", |
190 | | - " # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n", |
191 | | - " # handful of benchmarks for training and validation.\n", |
192 | | - " train_benchmarks = list(islice(npb.benchmarks(), 55))\n", |
193 | | - " train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n", |
194 | | - " # We will use the entire chstone-v0 dataset for testing.\n", |
195 | | - " test_benchmarks = list(chstone.benchmarks())\n", |
| 189 | + " # Each dataset has a `benchmarks()` method that returns an iterator over the\n", |
| 190 | + " # benchmarks within the dataset. Here we will use iterator slicing to grab a \n", |
| 191 | + " # handful of benchmarks for training and validation.\n", |
| 192 | + " train_benchmarks = list(islice(npb.benchmarks(), 55))\n", |
| 193 | + " train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n", |
| 194 | + " # We will use the entire chstone-v0 dataset for testing.\n", |
| 195 | + " test_benchmarks = list(chstone.benchmarks())\n", |
196 | 196 | "\n", |
197 | 197 | "print(\"Number of benchmarks for training:\", len(train_benchmarks))\n", |
198 | 198 | "print(\"Number of benchmarks for validation:\", len(val_benchmarks))\n", |
|
221 | 221 | "from compiler_gym.wrappers import CycleOverBenchmarks\n", |
222 | 222 | "\n", |
223 | 223 | "def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:\n", |
224 | | - " \"\"\"Make a reinforcement learning environment that cycles over the\n", |
225 | | - " set of training benchmarks in use.\n", |
226 | | - " \"\"\"\n", |
227 | | - " del args # Unused env_config argument passed by ray\n", |
228 | | - " return CycleOverBenchmarks(make_env(), train_benchmarks)\n", |
| 224 | + " \"\"\"Make a reinforcement learning environment that cycles over the\n", |
| 225 | + " set of training benchmarks in use.\n", |
| 226 | + " \"\"\"\n", |
| 227 | + " del args # Unused env_config argument passed by ray\n", |
| 228 | + " return CycleOverBenchmarks(make_env(), train_benchmarks)\n", |
229 | 229 | "\n", |
230 | 230 | "tune.register_env(\"compiler_gym\", make_training_env)" |
231 | 231 | ] |
|
245 | 245 | "# Lets cycle through a few calls to reset() to demonstrate that this environment\n", |
246 | 246 | "# selects a new benchmark for each episode.\n", |
247 | 247 | "with make_training_env() as env:\n", |
248 | | - " env.reset()\n", |
249 | | - " print(env.benchmark)\n", |
250 | | - " env.reset()\n", |
251 | | - " print(env.benchmark)\n", |
252 | | - " env.reset()\n", |
253 | | - " print(env.benchmark)" |
| 248 | + " env.reset()\n", |
| 249 | + " print(env.benchmark)\n", |
| 250 | + " env.reset()\n", |
| 251 | + " print(env.benchmark)\n", |
| 252 | + " env.reset()\n", |
| 253 | + " print(env.benchmark)" |
254 | 254 | ] |
255 | 255 | }, |
256 | 256 | { |
|
282 | 282 | "\n", |
283 | 283 | "# (Re)Start the ray runtime.\n", |
284 | 284 | "if ray.is_initialized():\n", |
285 | | - " ray.shutdown()\n", |
| 285 | + " ray.shutdown()\n", |
286 | 286 | "ray.init(include_dashboard=False, ignore_reinit_error=True)\n", |
287 | 287 | "\n", |
288 | 288 | "tune.register_env(\"compiler_gym\", make_training_env)\n", |
|
370 | 370 | "# performance on a set of benchmarks.\n", |
371 | 371 | "\n", |
372 | 372 | "def run_agent_on_benchmarks(benchmarks):\n", |
373 | | - " \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n", |
374 | | - " with make_env() as env:\n", |
| 373 | + " \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n", |
375 | 374 | " rewards = []\n", |
376 | | - " for i, benchmark in enumerate(benchmarks, start=1):\n", |
377 | | - " observation, done = env.reset(benchmark=benchmark), False\n", |
378 | | - " while not done:\n", |
379 | | - " action = agent.compute_action(observation)\n", |
380 | | - " observation, _, done, _ = env.step(action)\n", |
381 | | - " rewards.append(env.episode_reward)\n", |
382 | | - " print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n", |
| 375 | + " with make_env() as env:\n", |
| 376 | + " for i, benchmark in enumerate(benchmarks, start=1):\n", |
| 377 | + " observation, done = env.reset(benchmark=benchmark), False\n", |
| 378 | + " while not done:\n", |
| 379 | + " action = agent.compute_action(observation)\n", |
| 380 | + " observation, _, done, _ = env.step(action)\n", |
| 381 | + " rewards.append(env.episode_reward)\n", |
| 382 | + " print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n", |
383 | 383 | "\n", |
384 | | - " return rewards\n", |
| 384 | + " return rewards\n", |
385 | 385 | "\n", |
386 | 386 | "# Evaluate agent performance on the validation set.\n", |
387 | 387 | "val_rewards = run_agent_on_benchmarks(val_benchmarks)" |
|
417 | 417 | "outputs": [], |
418 | 418 | "source": [ |
419 | 419 | "# Finally lets plot our results to see how we did!\n", |
| 420 | + "%matplotlib inline\n", |
420 | 421 | "from matplotlib import pyplot as plt\n", |
421 | 422 | "\n", |
422 | 423 | "def plot_results(x, y, name, ax):\n", |
423 | | - " plt.sca(ax)\n", |
424 | | - " plt.bar(range(len(y)), y)\n", |
425 | | - " plt.ylabel(\"Reward (higher is better)\")\n", |
426 | | - " plt.xticks(range(len(x)), x, rotation = 90)\n", |
427 | | - " plt.title(f\"Performance on {name} set\")\n", |
| 424 | + " plt.sca(ax)\n", |
| 425 | + " plt.bar(range(len(y)), y)\n", |
| 426 | + " plt.ylabel(\"Reward (higher is better)\")\n", |
| 427 | + " plt.xticks(range(len(x)), x, rotation = 90)\n", |
| 428 | + " plt.title(f\"Performance on {name} set\")\n", |
428 | 429 | "\n", |
429 | 430 | "fig, (ax1, ax2) = plt.subplots(1, 2)\n", |
430 | 431 | "fig.set_size_inches(13, 3)\n", |
|
0 commit comments