Skip to content

Commit e7426e1

Browse files
authored
Merge pull request #570 from xristlamp/main
Add LogisticRegressionOperator with API training support and integration test
2 parents 36fe918 + 160d04e commit e7426e1

File tree

6 files changed

+167
-4
lines changed

6 files changed

+167
-4
lines changed

python/src/pywy/basic/model/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ def __init__(self, out: Op):
2727

2828
def get_out(self):
2929
return self.out
30+
31+
class LogisticRegression(Op):
    """``Op`` subclass identifying a logistic regression model.

    The instance is always initialised with the ``FLOAT32`` data type.
    """

    def __init__(self, name=None):
        """Initialise the underlying ``Op``.

        :param name: optional name, passed through unchanged to ``Op``.
        """
        dtype = Op.DType.FLOAT32
        super(LogisticRegression, self).__init__(dtype, name)
34+
35+

python/src/pywy/dataquanta.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from pywy.operators import *
2424
from pywy.basic.data.record import Record
2525
from pywy.basic.model.option import Option
26-
from pywy.basic.model.models import Model
26+
from pywy.basic.model.models import (Model, LogisticRegression)
27+
2728

2829

2930
class Configuration:
@@ -193,6 +194,18 @@ def predict(
193194
that._connect(op, 1)
194195
)
195196

197+
198+
def train_logistic_regression(
        self: "DataQuanta[In]",
        labels: "DataQuanta[In]",
        fit_intercept: bool = True
) -> "DataQuanta[Out]":
    """Connect this data quanta and ``labels`` to a ``LogisticRegression`` op.

    :param labels: data quanta connected to the op with index 1
        (``self`` is connected with index 0).
    :param fit_intercept: NOTE(review): accepted but not forwarded to the
        op anywhere in this method -- confirm whether that is intended.
    :return: a new ``DataQuanta`` wrapping the op in this instance's context.
    """
    trainer = LogisticRegression()
    # Index 0 is this data quanta, index 1 is the labels.
    for index, source in enumerate((self, labels)):
        source._connect(trainer, index)
    return DataQuanta(self.context, trainer)
207+
208+
196209
def store_textfile(self: "DataQuanta[In]", path: str, input_type: GenericTco = None) -> None:
197210
last: List[SinkOperator] = [
198211
cast(
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
from pywy.dataquanta import WayangContext
20+
from pywy.platforms.java import JavaPlugin
21+
from pywy.platforms.spark import SparkPlugin
22+
23+
class TestTrainLogisticRegression(unittest.TestCase):
    """Trains a logistic regression through the pywy API and predicts with it."""

    def test_train_and_predict(self):
        # Register both the Java and the Spark plugin on a fresh context.
        plugins = {JavaPlugin, SparkPlugin}
        ctx = WayangContext().register(plugins)

        # Four two-dimensional samples with one label each.
        feature_rows = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
        label_values = [1.0, 1.0, 0.0, 0.0]
        features = ctx.load_collection(feature_rows)
        labels = ctx.load_collection(label_values)

        # Train on the samples, then predict on the very same features.
        result = features.train_logistic_regression(labels).predict(features).collect()
        print("Predictions:", result)

        # One prediction per sample, each of them a binary class value.
        self.assertEqual(len(result), 4)
        allowed = [0.0, 1.0]
        for pred in result:
            self.assertIn(pred, allowed)


if __name__ == "__main__":
    unittest.main()

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.apache.wayang.core.plan.wayangplan._
3737
import org.apache.wayang.core.platform.Platform
3838
import org.apache.wayang.core.util.{Tuple => WayangTuple}
3939
import org.apache.wayang.basic.data.{Tuple2 => WayangTuple2}
40-
import org.apache.wayang.basic.model.DLModel;
40+
import org.apache.wayang.basic.model.{DLModel, LogisticRegressionModel};
4141
import org.apache.wayang.commons.util.profiledb.model.Experiment
4242
import com.google.protobuf.ByteString;
4343
import org.apache.wayang.api.python.function._
@@ -105,6 +105,17 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
105105
udfLoad: LoadProfileEstimator = null): DataQuanta[NewOut] =
106106
mapPartitionsJava(toSerializablePartitionFunction(udf), selectivity, udfLoad)
107107

108+
109+
/**
  * Feed this instance and the given labels into a [[LogisticRegressionOperator]].
  *
  * @param labels       the [[DataQuanta]] connected to input 1 of the operator (this instance goes to input 0)
  * @param fitIntercept flag passed to the [[LogisticRegressionOperator]] constructor
  * @return the operator, exposed as a [[DataQuanta]] of [[LogisticRegressionModel]]
  */
def trainLogisticRegression(labels: DataQuanta[java.lang.Double], fitIntercept: Boolean): DataQuanta[LogisticRegressionModel] = {
  val trainingOperator = new LogisticRegressionOperator(fitIntercept)
  this.connectTo(trainingOperator, 0)
  labels.connectTo(trainingOperator, 1)
  trainingOperator
}
115+
116+
117+
118+
108119
/**
109120
* Feed this instance into a [[MapPartitionsOperator]].
110121
*

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ import java.util.{Collection => JavaCollection}
2727
import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator}
2828
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
2929
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
30-
import org.apache.wayang.basic.model.{DLModel, Model}
31-
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
30+
import org.apache.wayang.basic.model.{DLModel, Model, LogisticRegressionModel}
31+
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator, LogisticRegressionOperator}
3232
import org.apache.wayang.commons.util.profiledb.model.Experiment
3333
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
3434
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
@@ -38,6 +38,9 @@ import org.apache.wayang.core.plan.wayangplan.{Operator, OutputSlot, UnarySource
3838
import org.apache.wayang.core.platform.Platform
3939
import org.apache.wayang.core.types.DataSetType
4040
import org.apache.wayang.core.util.{Logging, ReflectionUtils, WayangCollections, Tuple => WayangTuple}
41+
import org.apache.wayang.core.plan.wayangplan.OutputSlot
42+
43+
4144

4245
import scala.collection.mutable.ListBuffer
4346
import scala.reflect.ClassTag
@@ -288,6 +291,12 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
288291
option: DLTrainingOperator.Option) =
289292
new DLTrainingDataQuantaBuilder(this, that, model, option)
290293

294+
/**
  * Feed the built [[DataQuanta]] of this and the given instance into a
  * [[org.apache.wayang.basic.operators.LogisticRegressionOperator]].
  *
  * @param that         the [[DataQuantaBuilder]] providing the labels
  * @param fitIntercept flag handed to the [[LogisticRegressionDataQuantaBuilder]]; defaults to `true`
  * @return a [[LogisticRegressionDataQuantaBuilder]]
  */
def trainLogisticRegression(that: DataQuantaBuilder[_, java.lang.Double], fitIntercept: Boolean = true): LogisticRegressionDataQuantaBuilder = {
  // Unchecked cast: this builder is treated as producing Array[Double] features.
  val features = this.asInstanceOf[DataQuantaBuilder[_, Array[Double]]]
  new LogisticRegressionDataQuantaBuilder(features, that, fitIntercept)
}
296+
297+
298+
299+
291300
/**
292301
* Feed the built [[DataQuanta]] of this and the given instance into a
293302
* [[org.apache.wayang.basic.operators.PredictOperator]].
@@ -298,6 +307,8 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
298307
def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) =
299308
new PredictDataQuantaBuilder(this.asInstanceOf[DataQuantaBuilder[_, Model]], that, resultType)
300309

310+
311+
301312
/**
302313
* Feed the built [[DataQuanta]] of this and the given instance into a
303314
* [[org.apache.wayang.basic.operators.CoGroupOperator]].
@@ -1765,6 +1776,33 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil
17651776
override protected def build: DataQuanta[T] = _dataQuanta
17661777
}
17671778

1779+
/**
 * [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.LogisticRegressionOperator]]s.
 *
 * @param inputDataQuanta0 [[DataQuantaBuilder]] for the features
 * @param inputDataQuanta1 [[DataQuantaBuilder]] for the labels
 * @param fitIntercept     flag forwarded to `trainLogisticRegression` when building; defaults to `true`
 */
class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_, Array[Double]],
                                          inputDataQuanta1: DataQuantaBuilder[_, java.lang.Double],
                                          fitIntercept: Boolean = true)
                                         (implicit javaPlanBuilder: JavaPlanBuilder)
  extends BasicDataQuantaBuilder[LogisticRegressionDataQuantaBuilder, LogisticRegressionModel] {

  // Set the output type up front to LogisticRegressionModel.
  locally {
    this.outputTypeTrap.dataSetType = dataSetType[LogisticRegressionModel]
  }

  override protected def build: DataQuanta[LogisticRegressionModel] = {
    // Resolve the features first, then the labels, then wire both into the training step.
    val features = inputDataQuanta0.dataQuanta()
    val labels = inputDataQuanta1.dataQuanta()
    features.trainLogisticRegression(labels, fitIntercept)
  }

}
1802+
1803+
1804+
1805+
17681806
/**
17691807
* This is not an actual [[DataQuantaBuilder]] but rather decorates such a [[DataQuantaBuilder]] with a key.
17701808
*/

wayang-tests-integration/src/test/java/org/apache/wayang/tests/SparkIntegrationIT.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import org.apache.wayang.basic.WayangBasics;
2222
import org.apache.wayang.basic.data.Tuple2;
2323
import org.apache.wayang.basic.operators.*;
24+
import org.apache.wayang.basic.model.LogisticRegressionModel;
25+
import org.apache.wayang.api.DataQuanta;
26+
import org.apache.wayang.api.JavaPlanBuilder;
2427
import org.apache.wayang.core.api.Configuration;
2528
import org.apache.wayang.core.api.Job;
2629
import org.apache.wayang.core.api.WayangContext;
@@ -36,13 +39,18 @@
3639
import org.junit.Assert;
3740
import org.junit.Test;
3841

42+
43+
44+
45+
3946
import java.io.IOException;
4047
import java.net.URISyntaxException;
4148
import java.nio.file.Files;
4249
import java.nio.file.Paths;
4350
import java.util.*;
4451
import java.util.stream.Collectors;
4552
import java.util.stream.Stream;
53+
import static org.junit.jupiter.api.Assertions.*;
4654

4755
/**
4856
* Test the Spark integration with Wayang.
@@ -492,6 +500,51 @@ public void testLogisticRegressionOperator() {
492500
}
493501
}
494502

503+
@Test
504+
public void testLogisticRegressionWithAPI() {
505+
WayangContext context = new WayangContext()
506+
.with(Spark.basicPlugin())
507+
.with(Spark.mlPlugin());
508+
509+
JavaPlanBuilder planBuilder = new JavaPlanBuilder(context)
510+
.withJobName("Logistic Regression Test")
511+
.withUdfJarOf(this.getClass());
512+
513+
// Sample training data
514+
List<double[]> features = Arrays.asList(
515+
new double[]{0.0, 1.0},
516+
new double[]{1.0, 0.0},
517+
new double[]{1.0, 1.0},
518+
new double[]{0.0, 0.0}
519+
);
520+
List<Double> labels = Arrays.asList(1.0, 1.0, 0.0, 0.0);
521+
522+
// Build the pipeline using DataQuantaBuilder
523+
LogisticRegressionModel model = planBuilder
524+
.loadCollection(features).withName("Load Features")
525+
.trainLogisticRegression(
526+
planBuilder.loadCollection(labels).withName("Load Labels"),
527+
true
528+
)
529+
.collect()
530+
.iterator()
531+
.next();
532+
533+
// Predict using the model
534+
Collection<Double> predictions = planBuilder
535+
.loadCollection(Collections.singletonList(model))
536+
.predict(planBuilder.loadCollection(features), Double.class)
537+
.collect();
538+
539+
540+
assertEquals(4, predictions.size());
541+
for (double prediction : predictions) {
542+
assertTrue(prediction == 0.0 || prediction == 1.0);
543+
}
544+
}
545+
546+
547+
495548

496549
@Test
497550
public void testKMeans() {

0 commit comments

Comments
 (0)