Skip to content

Commit fd7f913

Browse files
committed
back to sgd for lin vec.
1 parent 9c1a3f6 commit fd7f913

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vsms/search_loop_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def fit_rank2(*, mod, X, y, batch_size, max_examples, valX=None, valy=None, logg
301301

302302
def adjust_vec(vec, Xt, yt, learning_rate, loss_margin, max_examples, minibatch_size):
303303
vec = torch.from_numpy(vec).type(torch.float32)
304-
mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.Adam, learning_rate=learning_rate, init_vec=vec)
304+
mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.SGD, learning_rate=learning_rate, init_vec=vec)
305305
fit_rank2(mod=mod, X=Xt.astype('float32'), y=yt.astype('float'),
306306
max_examples=max_examples, batch_size=minibatch_size,max_epochs=1)
307307
newvec = mod.vec.detach().numpy().reshape(1,-1)

0 commit comments

Comments
 (0)