-
I've implemented the REINFORCE algorithm in PyTorch and planned to port it to JAX/Flax. While doing so, I got stuck on a problem. I'll attach both my PyTorch implementation and my partially completed JAX/Flax version below. Please share a JAX version of the REINFORCE algorithm if you have come across one.
-
You can store the collected states, actions and returns and accumulate the loss in a Python loop:

```python
# collect states, actions and returns from agent in environment
def compute_loss(params):
    loss = 0
    for state, action, ret in zip(states, actions, returns):
        logits = policy_network.apply(params, state)
        log_prob = log_prob_func(logits, action)  # i.e. log_softmax(logits)[action]
        loss = loss + ret * log_prob
    return -loss
```

Or stack the collected data into arrays and vectorize with `jax.vmap`:

```python
# collect states, actions and returns from agent in environment, and stack to array
def compute_loss(params):
    logits = jax.vmap(lambda s: policy_network.apply(params, s))(states)
    log_probs = jax.vmap(log_prob_func)(logits, actions)
    return jnp.sum(-returns * log_probs)
```
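Both formulations compute the same scalar loss. A self-contained check with a toy linear policy (the policy, `log_prob_func`, and all data below are stand-ins I made up for illustration, not the asker's network):

```python
import jax
import jax.numpy as jnp

def apply_policy(params, state):
    return state @ params  # toy linear policy: logits = state . W

def log_prob_func(logits, action):
    return jax.nn.log_softmax(logits)[action]

def loss_loop(params, states, actions, returns):
    # Python-loop variant: accumulate ret * log_prob one transition at a time
    loss = 0.0
    for state, action, ret in zip(states, actions, returns):
        loss = loss + ret * log_prob_func(apply_policy(params, state), action)
    return -loss

def loss_vmap(params, states, actions, returns):
    # vectorized variant: one vmapped forward pass over the stacked states
    logits = jax.vmap(lambda s: apply_policy(params, s))(states)
    log_probs = jax.vmap(log_prob_func)(logits, actions)
    return jnp.sum(-returns * log_probs)

params = jax.random.normal(jax.random.PRNGKey(0), (4, 2))   # 4-dim states, 2 actions
states = jax.random.normal(jax.random.PRNGKey(1), (5, 4))   # 5 collected states
actions = jnp.array([0, 1, 0, 1, 1])
returns = jnp.array([1.0, 0.9, 0.8, 0.7, 0.6])

print(jnp.allclose(loss_loop(params, states, actions, returns),
                   loss_vmap(params, states, actions, returns)))  # True
```

Either variant can then be differentiated with `jax.grad` with respect to `params`.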
-
Since REINFORCE is an on-policy algorithm, it throws away the collected trajectories after updating the policy (which happens at the end of each episode), so the rollout can live inside the loss function:

```python
def compute_loss(env, policy, params, nb_timesteps, discount_factor=0.99):
    state = env.reset()
    rewards = []
    log_probs = []
    for timestep in range(nb_timesteps):
        # sample an action and compute the log-probability of that action
        action, log_prob = act(policy, params, state)
        state, reward, done, info = env.step(action)
        log_probs.append(log_prob)
        rewards.append(reward)
        if done:
            break
    # calculate returns using rewards & discount factor
    returns = []
    G = 0.0
    for reward in reversed(rewards):
        G = reward + discount_factor * G
        returns.insert(0, G)
    loss = jnp.sum(jnp.asarray(log_probs) * jnp.asarray(returns))
    return -loss
```

Then, use
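The sentence above is cut off in the thread; presumably it continues with differentiating the loss and updating the parameters. A minimal sketch of that update pattern (a toy quadratic loss stands in for the REINFORCE loss here, and all names are illustrative assumptions, not from the thread):

```python
import jax
import jax.numpy as jnp

# Stand-in loss: any scalar-valued function of a params pytree works the
# same way with jax.grad as the REINFORCE loss above.
def compute_loss(params):
    return jnp.sum(params["w"] ** 2) + jnp.sum(params["b"] ** 2)

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array([0.5])}
learning_rate = 0.1

# One plain-SGD step: differentiate, then update every leaf of the pytree.
grads = jax.grad(compute_loss)(params)
params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print(params["w"])  # [0.8, -1.6]: each leaf moved against its gradient
```

In practice an optimizer library such as optax replaces the hand-written `tree_map` update, but the grad-then-update structure is the same.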
-
There are two possible solutions to this problem. The first solution was suggested by @YouJiacheng.
The second solution was suggested by me (@BalajiAI); it is more efficient than the first in terms of both time and memory, since we don't have to recompute the outputs (logits) for every state using the neural network.