I need to find the maximum of the MyFunction below, which has 532 input variables. But when I let JAX compute the gradient, the result is a zero gradient, which doesn't help me at all with my task. If I understand it correctly, the reason is that the function relies on order comparisons (branching on whichever value is larger).
I would like to politely ask if someone could help me rewrite MyFunction so that it works the same way but gives meaningful gradients. Thank you very much.
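For illustration only, here is a minimal standalone sketch of the issue (it is not the real function below): a payoff that depends on its input only through a hard comparison has a zero derivative almost everywhere, while a sigmoid relaxation of the same comparison does not. The `temperature` parameter is a made-up knob for this sketch.

import jax
import jax.numpy as jnp

def hard_payoff(x):
    # The payoff depends on x only through the comparison,
    # so the derivative is zero (almost) everywhere.
    return jnp.where(x > 0.0, 1.0, 0.0)

def soft_payoff(x, temperature=0.1):
    # Sigmoid relaxation of the same comparison; gradients can flow.
    return jax.nn.sigmoid(x / temperature)

print(jax.grad(hard_payoff)(0.3))  # 0.0
print(jax.grad(soft_payoff)(0.3))  # non-zero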
from jax import grad
import jax.numpy as jnp
import numpy as np
# Example data - in the real case I have 150,000+ rows
data = jnp.array([[ 1. , 1.06, 9.77, 5. , 3. , 2. , 6. , 12. , 4. ,
10. , 1. , 7. , 1. , 12. , 10. , 12. , 4. , 10. ,
8. , 7. , 11. , 5. , 9. , 3. , 6. , 12. , 6. ,
5. , 3. , 5. , 9. , 8. , 9. , 10. , 11. , 12. ,
1. , 2. , 3. , 4. , 5. , 6. , 7. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[ 1. , 1.33, 3.33, 5. , 3. , 2. , 6. , 12. , 4. ,
10. , 1. , 7. , 1. , 12. , 10. , 12. , 4. , 10. ,
8. , 7. , 11. , 5. , 9. , 3. , 6. , 12. , 6. ,
5. , 3. , 5. , 9. , 8. , 9. , 10. , 11. , 12. ,
1. , 2. , 3. , 4. , 5. , 6. , 7. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[ 2. , 1.65, 2.07, 5. , 3. , 2. , 6. , 12. , 4. ,
10. , 1. , 7. , 1. , 12. , 10. , 12. , 4. , 8. ,
6. , 5. , 9. , 3. , 7. , 1. , 4. , 10. , 4. ,
3. , 1. , 3. , 7. , 10. , 11. , 12. , 1. , 2. ,
3. , 4. , 5. , 6. , 7. , 8. , 9. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ]])
def MyFunction(coefs, data):
    balance = float(len(data) * -1000)  # subtract 1000 per row up front
    for row in data:
        result = row[0]
        fOdds = row[1]
        dOdds = row[2]
        h1P = 0.0
        h2P = 0.0
        h3P = 0.0
        h4P = 0.0
        h5P = 0.0
        h6P = 0.0
        h7P = 0.0
        h8P = 0.0
        h9P = 0.0
        h10P = 0.0
        h11P = 0.0
        h12P = 0.0
        # accumulate a strength value for 14 entries, bucketed by their h value (0-11)
        for p in range(0, 14):
            s = int(row[3 + p] - 1)
            h = int(row[17 + p])
            r = int(row[43 + p])
            bCoef = coefs[p]
            sCoef = coefs[14 + (p * 12) + s]
            hCoef = coefs[182 + (p * 12) + h]
            if r == 1:
                rCoef = coefs[350 + p]
            else:
                rCoef = 1.0
            pStrength = bCoef * sCoef * hCoef * rCoef
            if h == 0:
                h1P += pStrength
            if h == 1:
                h2P += pStrength
            if h == 2:
                h3P += pStrength
            if h == 3:
                h4P += pStrength
            if h == 4:
                h5P += pStrength
            if h == 5:
                h6P += pStrength
            if h == 6:
                h7P += pStrength
            if h == 7:
                h8P += pStrength
            if h == 8:
                h9P += pStrength
            if h == 9:
                h10P += pStrength
            if h == 10:
                h11P += pStrength
            if h == 11:
                h12P += pStrength
        # scale each of the 12 buckets by a coefficient selected by its sign value
        for h in range(0, 12):
            hSign = int(row[31 + h] - 1)
            if h == 0:
                h1P *= coefs[364 + (h * 12) + hSign]
            if h == 1:
                h2P *= coefs[364 + (h * 12) + hSign]
            if h == 2:
                h3P *= coefs[364 + (h * 12) + hSign]
            if h == 3:
                h4P *= coefs[364 + (h * 12) + hSign]
            if h == 4:
                h5P *= coefs[364 + (h * 12) + hSign]
            if h == 5:
                h6P *= coefs[364 + (h * 12) + hSign]
            if h == 6:
                h7P *= coefs[364 + (h * 12) + hSign]
            if h == 7:
                h8P *= coefs[364 + (h * 12) + hSign]
            if h == 8:
                h9P *= coefs[364 + (h * 12) + hSign]
            if h == 9:
                h10P *= coefs[364 + (h * 12) + hSign]
            if h == 10:
                h11P *= coefs[364 + (h * 12) + hSign]
            if h == 11:
                h12P *= coefs[364 + (h * 12) + hSign]
        # combine the bucket strengths into two scores
        fPoints = 0.0
        dPoints = 0.0
        fPoints += h1P * coefs[508]
        fPoints += h2P * coefs[509]
        fPoints += h3P * coefs[510]
        fPoints += h4P * coefs[511]
        fPoints += h5P * coefs[512]
        fPoints += h6P * coefs[513]
        fPoints += h7P * coefs[514]
        fPoints += h8P * coefs[515]
        fPoints += h9P * coefs[516]
        fPoints += h10P * coefs[517]
        fPoints += h11P * coefs[518]
        fPoints += h12P * coefs[519]
        dPoints += h1P * coefs[520]
        dPoints += h2P * coefs[521]
        dPoints += h3P * coefs[522]
        dPoints += h4P * coefs[523]
        dPoints += h5P * coefs[524]
        dPoints += h6P * coefs[525]
        dPoints += h7P * coefs[526]
        dPoints += h8P * coefs[527]
        dPoints += h9P * coefs[528]
        dPoints += h10P * coefs[529]
        dPoints += h11P * coefs[530]
        dPoints += h12P * coefs[531]
        # the payoff is a constant (odds * 1000) gated by comparisons,
        # so the gradient w.r.t. coefs is zero almost everywhere
        if result == 1:
            if fPoints >= dPoints:
                balance += fOdds * 1000
        elif result == 2:
            if dPoints > fPoints:
                balance += dOdds * 1000
    return balance
derivFunction = grad(MyFunction)
coefs = np.random.sample(532)
# here I just get an array of 532 zeros instead of the derivatives...
print(derivFunction(coefs, data))
coefs = np.random.sample(532)
print(derivFunction(coefs, data))
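As a sketch of one possible direction (not a definitive fix), the final `if result == 1 ... elif result == 2 ...` payoff could be replaced by a sigmoid-gated soft version so that the gradient with respect to `fPoints` and `dPoints` (and hence `coefs`) is no longer zero. The helper name `soft_bet_payoff` and the `temperature` value are hypothetical; as the temperature shrinks, the gate approaches the original hard rule, but the smoothed objective is only an approximation of the original balance.

import jax
import jax.numpy as jnp

def soft_bet_payoff(fPoints, dPoints, result, fOdds, dOdds, temperature=0.05):
    # Smooth stand-in for:
    #   if result == 1 and fPoints >= dPoints: balance += fOdds * 1000
    #   if result == 2 and dPoints >  fPoints: balance += dOdds * 1000
    # The sigmoid turns the hard comparison into a gate in (0, 1).
    gate = jax.nn.sigmoid((fPoints - dPoints) / temperature)
    win_f = jnp.where(result == 1, fOdds * 1000.0 * gate, 0.0)
    win_d = jnp.where(result == 2, dOdds * 1000.0 * (1.0 - gate), 0.0)
    return win_f + win_d

# The derivative with respect to fPoints is now non-zero:
print(jax.grad(soft_bet_payoff)(1.2, 1.0, 1, 1.06, 9.77))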