Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@ class DistanceCalculator(override val uid: String)
val calculateDistance = udf({

(baseLocation: Row, locations: Seq[Row]) /* also coordinates */ => {

locations.map(location =>
Try(
Location.getDistance[GeoCoordinates](baseLocation, location)
).getOrElse(100000f)
)
.min
if (baseLocation == null || locations == null || locations.isEmpty) {
None
} else {
Some(locations.map(location =>
Try(
Location.getDistance[GeoCoordinates](baseLocation, location)
).getOrElse(100000f)
)
.min)
}
}
})

Expand All @@ -73,7 +76,7 @@ class DistanceCalculator(override val uid: String)
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
StructField($(outputCol), schemaFor[Float].dataType, nullable = false)
StructField($(outputCol), schemaFor[Float].dataType, nullable = true)
StructType(outputFields)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ class DistanceScorer(override val uid: String)
transformSchema(dataset.schema, logging = true)

val calculateDistance = udf{
(distance: Float, distanceFactor: Float) => scoreDistance(distance, distanceFactor)
(distance: Float, distanceFactor: Float) => {
if (distance == null) {
None
} else {
Some(scoreDistance(distance, distanceFactor))
}
}
}

dataset.withColumn($(outputCol), calculateDistance(col($(inputCol)), col($(distanceFactorCol))))
Expand All @@ -61,7 +67,7 @@ class DistanceScorer(override val uid: String)
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
StructField($(outputCol), schemaFor[Float].dataType, nullable = false)
StructField($(outputCol), schemaFor[Float].dataType, nullable = true)
StructType(outputFields)
}

Expand Down
18 changes: 13 additions & 5 deletions src/main/scala/com/haufe/umantis/ds/nlp/LinearWeigher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,26 @@ class LinearWeigher(override val uid: String)
transformSchema(dataset.schema, logging = true)

val linearCombination = udf {
(linearWeights:Row, values: Seq[Float]) => {
(values, linearWeights.getSeq[Float](1)).zipped.map(_ * _).sum / linearWeights.getFloat(0)
(linearWeights: Seq[Float], values: Seq[Any]) => {
val (valueSum, weightSum) = (values, linearWeights).zipped
.foldLeft((0.0f, 0.0f)){
case ((valueSum, weightSum), (value:Float, weights)) =>
(valueSum + value * weights, weightSum + weights)
case ((valueSum, weightSum), (_, weights)) =>
(valueSum, weightSum)
}

valueSum / weightSum
}
}
val inputColumns = array($(inputCols).map(col):_*)
val inputColumns = array($(inputCols).map(col): _*)
dataset.withColumn($(outputCol), linearCombination(col($(linearWeightsCol)), inputColumns))
}

override def transformSchema(schema: StructType): StructType = {

$(inputCols).foreach(validateColumnSchema(_,schemaFor[Float].dataType,schema))
validateColumnSchema($(linearWeightsCol), schemaFor[(Float,Array[Float])].dataType, schema)
$(inputCols).foreach(validateColumnSchema(_, schemaFor[Float].dataType, schema))
validateColumnSchema($(linearWeightsCol), schemaFor[Array[Float]].dataType, schema)

if (schema.fieldNames.contains($(outputCol))) {
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ class OtherLanguagesBooster(override val uid: String)
val boostOtherLanguage = udf {

(language: String, baseLang: String, similarity: Float) =>
if (language != baseLang)
similarity * $(boostsMap).getOrElse(baseLang, 1.0f)
if (language == null || baseLang == null)
None
else if (language != baseLang)
Some(similarity * $(boostsMap).getOrElse(baseLang, 1.0f))
else
similarity
Some(similarity)
}

dataset.withColumn($(outputCol), boostOtherLanguage(col($(languageCol)),
Expand All @@ -72,7 +74,7 @@ class OtherLanguagesBooster(override val uid: String)
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
StructField($(outputCol), schemaFor[Float].dataType, nullable = false)
StructField($(outputCol), schemaFor[Float].dataType, nullable = true)
StructType(outputFields)
}

Expand Down
37 changes: 34 additions & 3 deletions src/main/scala/com/haufe/umantis/ds/nlp/PipelineCreator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,28 @@ object ColnamesText {
}

class ColnamesTextSimilarity(val baseText: ColnamesText,
val varyingText: ColnamesText)
val varyingText: ColnamesText,
val additional: String = "")
extends Colnames {

def similarity: String =
s"${baseText.vector}__similarity__${varyingText.vector}"
s"${baseText.vector}__similarity${additional}__${varyingText.vector}"

def score: String =
s"${baseText.vector}__score__${varyingText.vector}"
s"${baseText.vector}__score${additional}__${varyingText.vector}"

def baseVector: String =
s"${baseText.vector}__baseVector${additional}__${varyingText.vector}"

def baseLanguage: String =
s"${baseText.vector}__baseLanguage${additional}__${varyingText.vector}"
}
// Convenience factory methods mirroring the ColnamesTextSimilarity class constructor,
// so call sites can omit `new`.
object ColnamesTextSimilarity {
// Column names for a base/varying text pair with no extra name component
// (relies on the class's default for `additional`).
def apply(baseText: ColnamesText, varyingText: ColnamesText): ColnamesTextSimilarity =
new ColnamesTextSimilarity(baseText, varyingText)

// `additional` is spliced into the generated column names (e.g. similarity/score),
// allowing several similarity stages over the same column pair to coexist.
def apply(baseText: ColnamesText, varyingText: ColnamesText, additional: String): ColnamesTextSimilarity =
new ColnamesTextSimilarity(baseText, varyingText, additional)
}

class ColnamesURL(val colName: String) extends Colnames {
Expand Down Expand Up @@ -150,7 +160,10 @@ object Stg {
val StopWordsRemover: String = "StopWordsRemover"
val EmbeddingsModel: String = "EmbeddingsModel"
val OtherLanguageBooster: String = "OtherLanguageBooster"

val SimilarityScorer: String = "SimilarityScorer"
val SimilarityScorerArrayMax: String = "SimilarityScorerArrayMax"
val SimilarityScorerArrayMean: String = "SimilarityScorerArrayMean"

val WordMoverDistance: String = "WordMoverDistance"
val NormalizedBagOfWords: String = "NormalizedBagOfWords"
Expand Down Expand Up @@ -321,6 +334,22 @@ object DsPipeline extends ConfigGetter {
.setOutputCol(c.similarity)
.setBaseVectorCol(c.baseText.vector)

def similarityScorerDenseVectorAgregator(
c: ColnamesTextSimilarity,
aggregationFunction: Seq[Float] => Float)
: SimilarityScorerMultipleBaseVectorsDenseVector =
new SimilarityScorerMultipleBaseVectorsDenseVector()
.setInputCol(c.varyingText.vector)
.setAggregationFunction(aggregationFunction)
.setOutputCol(c.similarity)
.setBaseVectorCol(c.baseVector)

def getSimilarityScorerDenseVectorMax(c: ColnamesTextSimilarity): SimilarityScorerMultipleBaseVectorsDenseVector =
similarityScorerDenseVectorAgregator(c, ColumnsAggregator.max)

def getSimilarityScorerDenseVectorMean(c: ColnamesTextSimilarity): SimilarityScorerMultipleBaseVectorsDenseVector =
similarityScorerDenseVectorAgregator(c, ColumnsAggregator.mean)

def getCoordinatesFetcher(c: ColnamesLocation): CoordinatesFetcher =
new CoordinatesFetcher()
.setLocationCol(c.location)
Expand Down Expand Up @@ -377,6 +406,8 @@ object DsPipeline extends ConfigGetter {
Stg.EmbeddingsModel -> getEmbeddingsModel _,
Stg.OtherLanguageBooster -> getOtherLanguageBooster _,
Stg.SimilarityScorer -> getSimilarityScorerDenseVector _,
Stg.SimilarityScorerArrayMax -> getSimilarityScorerDenseVectorMax _,
Stg.SimilarityScorerArrayMean -> getSimilarityScorerDenseVectorMean _,

Stg.WordMoverDistance -> getWordMoverDistance _,
Stg.NormalizedBagOfWords -> getNormalizedBagOfWordsTransformer _,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,20 @@ class SimilarityScorerDenseVector(override val uid: String)

val calculateSimilarity = udf {
(v0: DenseVector, v1: DenseVector) => {
val size = v0.size
val base = v0.values
val vector = v1.values
var i = 0
var dotProduct = 0.0
while (i < size) {
dotProduct += base(i) * vector(i)
i += 1
if (v0 == null || v1 == null) {
None
} else {
val size = v0.size
val base = v0.values
val vector = v1.values
var i = 0
var dotProduct = 0.0
while (i < size) {
dotProduct += base(i) * vector(i)
i += 1
}
Some(dotProduct.toFloat)
}
dotProduct.toFloat
}
}

Expand All @@ -69,7 +73,7 @@ class SimilarityScorerDenseVector(override val uid: String)
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
StructField($(outputCol), schemaFor[Float].dataType, nullable = false)
StructField($(outputCol), schemaFor[Float].dataType, nullable = true)
StructType(outputFields)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/**
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package com.haufe.umantis.ds.nlp

import com.haufe.umantis.ds.nlp.params._
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}


/**
 * Spark ML [[Transformer]] that scores each row against a set of base vectors.
 *
 * For every row it computes the dot product of the row's dense vector
 * (`inputCol`) against each vector in the base-vector sequence (`baseVectorCol`),
 * then reduces the per-base scores with the configured aggregation function
 * (e.g. max or mean). Null inputs — and an empty base-vector sequence — produce
 * a null score rather than failing the task.
 */
class SimilarityScorerMultipleBaseVectorsDenseVector(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol
    with HasAggregationFunction with HasBaseVectorCol with ValidateColumnSchema {

  // Fixed from "SimilarityScorerDenseVector" (copy-paste from the single-base-vector
  // scorer): the UID prefix should name this class so stage UIDs are unambiguous.
  def this() = this(Identifiable.randomUID("SimilarityScorerMultipleBaseVectorsDenseVector"))

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  // Spark SQL data type of a single dense vector, used for schema validation.
  private val vectorUDT = ScalaReflection.schemaFor[DenseVector].dataType

  // The base-vector column holds a *sequence* of dense vectors (see the UDF
  // signature below), so it must validate as an array of the vector UDT.
  private val vectorSeqType = ScalaReflection.schemaFor[Seq[DenseVector]].dataType

  /**
   * Appends `outputCol` containing the aggregated similarity score.
   *
   * @param dataset input dataset; must contain `inputCol` (DenseVector) and
   *                `baseVectorCol` (Seq[DenseVector])
   * @return the dataset with the nullable Float score column appended
   */
  def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val aggregation: Seq[Float] => Float = $(aggregationFunction)

    val calculateSimilarity = udf {
      (baseVectors: Seq[DenseVector], rowVector: DenseVector) => {
        // Map nulls AND an empty base set to None: aggregations such as `max`
        // throw on an empty Seq, which would otherwise fail the whole task.
        if (baseVectors == null || rowVector == null || baseVectors.isEmpty) {
          None
        } else {
          val scores = baseVectors.map { baseVector =>
            // Plain dot product over the raw value arrays.
            // NOTE(review): assumes every base vector has the same dimensionality
            // as the row vector — confirm upstream embedding stage guarantees this.
            val size = baseVector.size
            val baseValues = baseVector.values
            val rowValues = rowVector.values
            var i = 0
            var dotProduct = 0.0
            while (i < size) {
              dotProduct += baseValues(i) * rowValues(i)
              i += 1
            }
            dotProduct.toFloat
          }
          Some(aggregation(scores))
        }
      }
    }

    dataset.withColumn($(outputCol), calculateSimilarity(col($(baseVectorCol)), col($(inputCol))))
  }

  /**
   * Validates the input columns and appends the (nullable) Float output column.
   *
   * @throws IllegalArgumentException if `outputCol` already exists or a column
   *                                  has an unexpected type
   */
  override def transformSchema(schema: StructType): StructType = {
    validateColumnSchema($(inputCol), vectorUDT, schema)
    // Fixed: the original validated baseVectorCol as a single DenseVector, but the
    // UDF consumes it as Seq[DenseVector]; a correctly-typed column would have been
    // rejected here.
    validateColumnSchema($(baseVectorCol), vectorSeqType, schema)

    if (schema.fieldNames.contains($(outputCol))) {
      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
    }
    val outputFields = schema.fields :+
      StructField($(outputCol), schemaFor[Float].dataType, nullable = true)
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): SimilarityScorerMultipleBaseVectorsDenseVector =
    defaultCopy[SimilarityScorerMultipleBaseVectorsDenseVector](extra)

}
Loading