First of all: I have little prior experience with TensorFlow; in fact, I'm trying to understand it by using your package so that I can work in Scala rather than Python. Thank you very much for your library, it helps a ton, since I'm a lot more comfortable working within Scala.
As a first step, I figured I'd translate a GloVe implementation from Python to Scala, namely: https://github.com/GradySimon/tensorflow-glove/blob/master/tf_glove.py (my translation is attached below; the relevant parts are all in the train method, lines 98ff).
I encountered two issues:
- In line 99 of the Python implementation it says
  `embedding_product = tf.reduce_sum(tf.multiply(focal_embedding, context_embedding), 1)`
  but I had to change the axis to -1 (as in `val embedding_product = tf.sum(tf.multiply(focal_embedding, context_embedding), -1)`), otherwise I'd get the error `Inputs to operation AddN of type AddN must have the same size and shape. Input 0: [512,50] != input 1: [512,1]`, caused by the computation of `distance_expr` (line 103 in the Python code, line 124 in the Scala code below). Is there a discrepancy between the Python implementation and yours regarding the axes?
- Having "fixed" that, I now get the error `Input to reshape is a tensor with 262144 values, but the requested shape has 512 [[{{node Gradients/SumGradient/Reshape_1}}]]`, which seems to occur somewhere in the `AdaGrad.minimize` computation. I'm at a loss (pun intended) as to where this comes from: the graph as implemented seems to have the correct dimensionalities everywhere (all the tensors in the sums seem to have the correct shape `Shape(512)`; see the shape check sketched right after this list), so I wouldn't know what's wrong other than that there's an error somewhere in the implementation of the gradients involved... any help would be most welcome.
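To make the shape claims checkable, here is a small self-contained sketch (it only uses ops that already appear in the translation below, plus `Output.shape` for the statically inferred shape; the concrete sizes 512 and 50 are the batch size and vector size from the code):

import org.platanios.tensorflow.api._

// shape check for the summands fed to tf.addN, using the sizes from the code below
object ShapeCheck extends App {
  val focal_embedding = tf.placeholder[Float](Shape(512, 50))
  val context_embedding = tf.placeholder[Float](Shape(512, 50))
  val bias = tf.placeholder[Float](Shape(512))
  val count = tf.placeholder[Float](Shape(512))
  val embedding_product = tf.sum(tf.multiply(focal_embedding, context_embedding), -1)
  println(embedding_product.shape) // Shape(512) with axis -1
  val summed = tf.addN(Seq(embedding_product, bias, bias, tf.negate(tf.log(count))))
  println(summed.shape) // Shape(512)
}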
package com.jazzpirate.glove
import com.jazzpirate.tensorflow._
import info.kwarc.mmt.api.utils.File
import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.core.client.FeedMap
import org.platanios.tensorflow.api.ops.Embedding.DivStrategy
import org.platanios.tensorflow.api.ops.Output
import org.platanios.tensorflow.api.tf
import org.platanios.tensorflow.api.ops.variables.RandomUniformInitializer
import org.platanios.tensorflow.api.tensors.Tensor
import scala.collection.parallel.mutable
class TFGloVe() {
  private object Vocab {
    private var itos_map: List[String] = Nil

    def setElems(ls: List[String]) = {
      itos_map = ls
      stoi_map = mutable.ParHashMap.empty
      w = None
      _size = itos_map.length
      itos_map.zipWithIndex.foreach {
        case (s, i) => stoi_map(s) = i //+1
      }
    }

    def load(file: File) = {
      val vocab = file.addExtension("vocab")
      // assign to the outer field instead of shadowing it with a local val,
      // so that itos/stoi work after loading
      itos_map = File.read(vocab).split('\t').toList
      stoi_map = mutable.ParHashMap.empty
      _size = itos_map.length
      itos_map.zipWithIndex.foreach {
        case (s, i) => stoi_map(s) = i //+1
      }
      val matrix = vocab.setExtension("matrix")
      ???
      // w = Some(Nd4j.readTxt(matrix.toString()))
    }

    def setMatrix(matrix: scala.collection.mutable.ArrayBuffer[MutableVector]) = w = Some(matrix)

    private var stoi_map: mutable.ParHashMap[String, Int] = mutable.ParHashMap.empty
    def stoi(s: String) = stoi_map(s)
    def itos(i: Int) = itos_map(i) //-1)

    var _size = 0
    def size = _size

    var w: Option[scala.collection.mutable.ArrayBuffer[MutableVector]] = None

    def getValue(s: String) = {
      val i = stoi(s)
      w.get(i) + w.get(i + size)
      // w.get.getRow(i).add(w.get.getRow(i+size))
    }
    // lazy val aarrs = (0 until (w.get.rows()/2)).map(i => w.get.getRow(i).add(w.get.getRow(i+size)))

    def getString(a: Tensor[Double]) = {
      ??? // something with aarrs
    }

    def save(file: File) = {
      val vocab = file.addExtension("vocab")
      vocab.createNewFile()
      File.write(vocab, itos_map.mkString("\t"))
      val matrix = vocab.setExtension("matrix")
      matrix.createNewFile()
      ???
      // Nd4j.writeTxt(w.get,matrix.toString())
    }
  }
  def save(file: File) = Vocab.save(file)
  def load(file: File) = Vocab.load(file)
  def apply(s: String) = Vocab.getValue(s)

  def train(ls: List[List[String]],
            windowsize: Int = 10,
            minval: Double = 1,
            batch_size: Int = 512,
            learning_rate: Double = 0.05,
            alpha: Double = 0.75,
            x_max: Double = 100,
            vector_size: Int = 50,
            iterations: Int = 25): Unit = {
    val distinct = ls.flatten.distinct
    Vocab.setElems(distinct)
    println(distinct.length + " Tokens")
    val hm = computeMatrix(ls, windowsize, minval)
    train(hm, batch_size, vector_size, iterations, x_max, alpha, learning_rate)
  }
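  // Graph construction and training loop, mirroring the reference tf_glove.py.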
  private def train(cooc: mutable.ParHashMap[(Int, Int), Double], batch_size: Int, vector_size: Int,
                    iterations: Int, x_max: Double, alpha: Double, learning_rate: Double): Unit = {
    println("Graph building...")
    val session = core.client.Session()
    val init = RandomUniformInitializer(-1f, 1f)
    val count_max = tf.constant[Float](Tensor(x_max.toFloat), name = "max_cooccurrence_cap")
    val scaling_factor = tf.constant[Float](Tensor(alpha.toFloat), name = "scaling_factor")
    val focal_input = tf.placeholder[Int](Shape(batch_size), name = "focal_words")
    val context_input = tf.placeholder[Int](Shape(batch_size), name = "context_words")
    val cooccurrence_count = tf.placeholder[Float](Shape(batch_size), name = "cooccurrence_count")
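    // Two embedding matrices and two bias vectors: one set for words in the
    // focal role and one for words in the context role, as in the GloVe model.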
    val focal_embeddings = tf.variable[Float](name = "focal_embeddings", Shape(Vocab.size, vector_size), init)
    val context_embeddings = tf.variable[Float](name = "context_embeddings", Shape(Vocab.size, vector_size), init)
    val focal_biases = tf.variable[Float]("focal_biases", Shape(Vocab.size), init)
    val context_biases = tf.variable[Float]("context_biases", Shape(Vocab.size), init)
    val focal_embedding = org.platanios.tensorflow.api.ops.Embedding.embeddingLookup(focal_embeddings, focal_input, DivStrategy)
    val context_embedding = org.platanios.tensorflow.api.ops.Embedding.embeddingLookup(context_embeddings, context_input, DivStrategy)
    val focal_bias = org.platanios.tensorflow.api.ops.Embedding.embeddingLookup(focal_biases, focal_input, DivStrategy)
    val context_bias = org.platanios.tensorflow.api.ops.Embedding.embeddingLookup(context_biases, context_input, DivStrategy)
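    // The loss below implements, per batch entry (i, j):
    //   weight = min(1, (X_ij / x_max)^alpha)
    //   loss   = weight * (w_i . w~_j + b_i + b~_j - log X_ij)^2
    // so each summand fed to tf.addN (embedding product, both biases, log count)
    // should be a vector of shape Shape(batch_size), i.e. Shape(512).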
    val weighting_factor = tf.minimum(1.0f, tf.pow(tf.divide(cooccurrence_count, count_max), scaling_factor))
    val embedding_product = tf.sum(tf.multiply(focal_embedding, context_embedding), -1)
    val log_cooccurences = tf.log(cooccurrence_count.toFloat)
    val distance_expr = tf.square(tf.addN(Seq(embedding_product, focal_bias, context_bias, tf.negate(log_cooccurences))))
    val single_losses = tf.multiply(weighting_factor, distance_expr)
    val total_loss = tf.sum(single_losses)
    tf.summary.scalar("GloVe_loss", total_loss)
    val optimizer = tf.train.AdaGrad(learning_rate.toFloat).minimize(total_loss)
    val combined_embeddings = tf.add(focal_embeddings, context_embeddings, name = "combined_embeddings")

    println("Training...")
    session.run(targets = Set(tf.globalVariablesInitializer()))
    val writer = tf.summary.FileWriter(File("/home/jazzpirate/work/Scala/ML/data/tf").toPath, session.graph)
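    // Turn the co-occurrence map into (focal index, context index, count) triples
    // and group them into tensors of batch_size entries each.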
    val (allis, alljs, allcs) = cooc.toParArray.map { case ((i, j), c) => (Tensor(i), Tensor(j), Tensor(c.toFloat)) }.unzip3
    var data = (0 until allis.length by batch_size).map(ii => (
      Tensor(allis.slice(ii, ii + batch_size).toArray: _*),
      Tensor(alljs.slice(ii, ii + batch_size).toArray: _*),
      Tensor(allcs.slice(ii, ii + batch_size).toArray: _*)
    ))
    /*
    var data = (0 until Vocab.size).flatMap(i => (0 until Vocab.size).map{j =>
      val counts = if (i<j) cooc.getOrElse((i,j),0.0) else cooc.getOrElse((j,i),0.0)
      (0 until Vocab.size by batch_size).map(ii => ()) Tensor()
    })
    */
    (1 to iterations).foreach { it =>
      print("Iteration " + it + "/" + iterations + "... ")
      data = scala.util.Random.shuffle(data)
      data.indices.foreach { ii =>
        print("\rIteration " + it + "/" + iterations + ": " + (ii + 1) + "/" + data.length + " ")
        val (is, js, counts) = data(ii)
        //val counts = if (is<js) cooc.getOrElse((is,js),0.0) else cooc.getOrElse((js,is),0.0)
        val feed_dict = Map((focal_input, is), (context_input, js), (cooccurrence_count, counts)).asInstanceOf[Map[Output[_], Tensor[_]]]
        session.run(targets = Set(optimizer), feeds = FeedMap(feed_dict))
      }
    }
    val ret = session.run(fetches = Seq(combined_embeddings))
    writer.close()
    ret
    // Vocab.setMatrix(w)
  }
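  // Builds the sparse co-occurrence map ((focal, context) -> weight); each
  // unordered pair is stored once, and pairs with total weight <= minval are
  // dropped at the end.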
  private def computeMatrix(ls: List[List[String]], windowsize: Int, minval: Double) = {
    val hm = mutable.ParHashMap[(Int, Int), Double]()
    def update(p1: Int, p2: Int, inc: Double): Unit = {
      if (p2 < p1) update(p2, p1, inc) else {
        val old = hm.getOrElse((p1, p2), 0.0)
        hm.update((p1, p2), old + inc)
      }
    }
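    // For every position in a region, collect up to windowsize neighbours on each
    // side and weight each co-occurrence by 1/distance, as in the reference
    // Python implementation.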
    ls foreach { region =>
      region.indices.foreach { ri =>
        val start = ri - windowsize
        val end = ri + windowsize
        val word = Vocab.stoi(region(ri))
        val left = region.slice(Math.max(start, 0), Math.min(ri - 1, region.length + 1) + 1)
        val right = region.slice(Math.max(ri + 1, 0), Math.min(end, region.length + 1) + 1)
        left.reverse.zipWithIndex.foreach { case (s, i) =>
          val cont_word = Vocab.stoi(s)
          update(word, cont_word, 1.0 / (i + 1))
        }
        right.zipWithIndex.foreach { case (s, i) =>
          val cont_word = Vocab.stoi(s)
          update(word, cont_word, 1.0 / (i + 1))
        }
      }
    }
    // drop rare co-occurrences; filtering avoids mutating the map while iterating over it
    hm.filter { case (_, v) => v > minval }
  }
}