Skip to content

Commit dfe2eee

Browse files
committed
add: utils, methods to calculate piecewiseLinearCompression
1 parent a0d08d7 commit dfe2eee

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package org.apache.sysds.runtime.compress.colgroup.functional;
2+
3+
import org.apache.sysds.runtime.compress.CompressionSettings;
4+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5+
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.Collections;
9+
import java.util.List;
10+
11+
public class PiecewiseLinearUtils {
12+
13+
private PiecewiseLinearUtils() {
14+
15+
}
16+
17+
public static final class SegmentedRegression {
18+
private final int[] breakpoints;
19+
private final double[] slopes;
20+
private final double[] intercepts;
21+
22+
public SegmentedRegression(int[] breakpoints, double[] slopes, double[] intercepts) {
23+
this.breakpoints = breakpoints;
24+
this.slopes = slopes;
25+
this.intercepts = intercepts;
26+
}
27+
28+
public int[] getBreakpoints() {
29+
return breakpoints;
30+
}
31+
32+
public double[] getSlopes() {
33+
return slopes;
34+
}
35+
36+
public double[] getIntercepts() {
37+
return intercepts;
38+
}
39+
}
40+
41+
public static SegmentedRegression compressSegmentedLeastSquares(double[] column, CompressionSettings cs) {
42+
//compute Breakpoints for a Column with dynamic Programming
43+
final List<Integer> breakpointsList = computeBreakpoints(cs, column);
44+
final int[] breakpoints = breakpointsList.stream().mapToInt(Integer::intValue).toArray();
45+
46+
//get values for Regression
47+
final int numSeg = breakpoints.length - 1;
48+
final double[] slopes = new double[numSeg];
49+
final double[] intercepts = new double[numSeg];
50+
51+
// Regress per Segment
52+
for (int seg = 0; seg < numSeg; seg++) {
53+
final int SegStart = breakpoints[seg];
54+
final int SegEnd = breakpoints[seg + 1];
55+
56+
final double[] line = regressSegment(column, SegStart, SegEnd);
57+
slopes[seg] = line[0]; //slope regession line
58+
intercepts[seg] = line[1]; //intercept regression line
59+
}
60+
61+
return new SegmentedRegression(breakpoints, slopes, intercepts);
62+
}
63+
64+
public static SegmentedRegression compressSegmentedLeastSquaresV2(double[] column, CompressionSettings cs) {
65+
//compute Breakpoints for a Column with Greedy Algorithm
66+
67+
final List<Integer> breakpointsList = computeBreakpointsGreedy(column, cs);
68+
final int[] breakpoints = breakpointsList.stream().mapToInt(Integer::intValue).toArray();
69+
70+
//get values for Regression
71+
final int numSeg = breakpoints.length - 1;
72+
final double[] slopes = new double[numSeg];
73+
final double[] intercepts = new double[numSeg];
74+
75+
// Regress per Segment
76+
for (int seg = 0; seg < numSeg; seg++) {
77+
final int segstart = breakpoints[seg];
78+
final int segEnd = breakpoints[seg + 1];
79+
final double[] line = regressSegment(column, segstart, segEnd);
80+
slopes[seg] = line[0];
81+
intercepts[seg] = line[1];
82+
}
83+
return new SegmentedRegression(breakpoints,slopes, intercepts);
84+
}
85+
86+
public static double[] getColumn(MatrixBlock in, int colIndex) {
87+
final int numRows = in.getNumRows();
88+
final double[] column = new double[numRows];
89+
90+
for (int row = 0; row < numRows; row++) {
91+
column[row] = in.get(row, colIndex);
92+
}
93+
return column;
94+
}
95+
96+
public static List<Integer> computeBreakpoints(CompressionSettings cs, double[] column) {
97+
final int numElements = column.length;
98+
final double targetMSE = cs.getPiecewiseTargetLoss();
99+
100+
101+
// TODO: Maybe remove Fallback if no targetloss is given
102+
/*if (Double.isNaN(targetMSE) || targetMSE <= 0) {
103+
final double segmentPenalty = 2.0 * Math.log(numElements);
104+
return computeBreakpointsLambda(column, segmentPenalty);
105+
}*/
106+
107+
// max targetloss
108+
final double sseMax = numElements * targetMSE;
109+
double minLoss = 0.0;
110+
double maxLoss = numElements * 100.0;
111+
List<Integer> bestBreaks = null;
112+
//compute breakpoints
113+
while(maxLoss -minLoss > 1e-8) {
114+
final double currentLoss = 0.5 * (minLoss + maxLoss);
115+
final List<Integer> breaks = computeBreakpointsLambda(column, currentLoss);
116+
final double totalSSE = computeTotalSSE(column, breaks);
117+
if (totalSSE <= sseMax) {
118+
bestBreaks = breaks;
119+
minLoss = currentLoss;
120+
}
121+
else {
122+
maxLoss = currentLoss;
123+
}
124+
}
125+
126+
if (bestBreaks == null)
127+
bestBreaks = computeBreakpointsLambda(column, minLoss);
128+
129+
return bestBreaks;
130+
}
131+
132+
public static List<Integer> computeBreakpointsLambda(double[] column, double lambda) {
133+
final int numrows = column.length;
134+
final double[] costs = new double[numrows + 1]; //min Cost
135+
final int[] prevStart = new int[numrows + 1]; //previous Start
136+
costs[0] = 0.0;
137+
// Find Cost
138+
for (int rowEnd = 1; rowEnd <= numrows; rowEnd++) {
139+
costs[rowEnd] = Double.POSITIVE_INFINITY;
140+
//Test all possible Segment to find the lowest costs
141+
for (int rowStart = 0; rowStart < rowEnd; rowStart++) {
142+
//costs = current costs + segmentloss + penaltiy
143+
final double costCurrentSegment = computeSegmentCost(column, rowStart, rowEnd);
144+
final double totalCost = costs[rowStart] + costCurrentSegment + lambda;
145+
// Check if it is the better solution
146+
if (totalCost < costs[rowEnd]) {
147+
costs[rowEnd] = totalCost;
148+
prevStart[rowEnd] = rowStart;
149+
}
150+
}
151+
}
152+
//Check the optimal segmentlimits
153+
final List<Integer> segmentLimits = new ArrayList<>();
154+
int breakpointIndex = numrows;
155+
while (breakpointIndex > 0) {
156+
segmentLimits.add(breakpointIndex);
157+
breakpointIndex = prevStart[breakpointIndex];
158+
}
159+
segmentLimits.add(0);
160+
Collections.sort(segmentLimits);
161+
return segmentLimits;
162+
}
163+
164+
public static double computeSegmentCost(double[] column, int start, int end) {
165+
final int segSize = end - start;
166+
if (segSize <= 1)
167+
return 0.0;
168+
169+
final double[] ab = regressSegment(column, start, end); //Regressionline
170+
final double slope = ab[0];
171+
final double intercept = ab[1];
172+
173+
double sumSquaredError = 0.0;
174+
for (int i = start; i < end; i++) {
175+
final double rowIdx = i;
176+
final double actualValue = column[i];
177+
final double predictedValue = slope * rowIdx + intercept;
178+
final double difference = actualValue - predictedValue;
179+
sumSquaredError += difference * difference;
180+
}
181+
return sumSquaredError;
182+
}
183+
184+
public static double computeTotalSSE(double[] column, List<Integer> breaks) {
185+
double total = 0.0;
186+
for (int s = 0; s < breaks.size() - 1; s++) {
187+
final int start = breaks.get(s);
188+
final int end = breaks.get(s + 1);
189+
total += computeSegmentCost(column, start, end);
190+
}
191+
return total;
192+
}
193+
194+
public static double[] regressSegment(double[] column, int start, int end) {
195+
final int numElements = end - start;
196+
if (numElements <= 0)
197+
return new double[] {0.0, 0.0};
198+
199+
double sumOfRowIndices = 0, sumOfColumnValues = 0, sumOfRowIndicesSquared = 0, productRowIndexTimesColumnValue = 0;
200+
for (int i = start; i < end; i++) {
201+
final double x = i;
202+
final double y = column[i];
203+
sumOfRowIndices += x;
204+
sumOfColumnValues += y;
205+
sumOfRowIndicesSquared += x * x;
206+
productRowIndexTimesColumnValue += x * y;
207+
}
208+
209+
final double numPointsInSegmentDouble = numElements;
210+
final double denominatorForSlope = numPointsInSegmentDouble * sumOfRowIndicesSquared - sumOfRowIndices * sumOfRowIndices;
211+
final double slope;
212+
final double intercept;
213+
if (denominatorForSlope == 0) {
214+
slope = 0.0;
215+
intercept = sumOfColumnValues / numPointsInSegmentDouble;
216+
}
217+
else {
218+
slope = (numPointsInSegmentDouble * productRowIndexTimesColumnValue - sumOfRowIndices * sumOfColumnValues) / denominatorForSlope;
219+
intercept = (sumOfColumnValues - slope * sumOfRowIndices) / numPointsInSegmentDouble;
220+
}
221+
return new double[] {slope, intercept};
222+
}
223+
public static List<Integer> computeBreakpointsGreedy(double[] column, CompressionSettings cs) {
224+
final int numElements = column.length;
225+
final double targetMSE = cs.getPiecewiseTargetLoss();
226+
if (Double.isNaN(targetMSE) || targetMSE <= 0) {
227+
return Arrays.asList(0, numElements); // Fallback: ein Segment
228+
}
229+
230+
List<Integer> breakpoints = new ArrayList<>();
231+
breakpoints.add(0);
232+
int currentStart = 0;
233+
234+
while (currentStart < numElements) {
235+
int bestEnd = numElements; // Default: Rest als Segment
236+
for (int end = currentStart + 1; end <= numElements; end++) {
237+
double sse = computeSegmentCost(column, currentStart, end);
238+
double sseMax = (end - currentStart) * targetMSE;
239+
if (sse > sseMax) {
240+
bestEnd = end - 1; // Letzter gültiger Endpunkt
241+
break;
242+
}
243+
}
244+
breakpoints.add(bestEnd);
245+
currentStart = bestEnd;
246+
}
247+
248+
if (breakpoints.get(breakpoints.size() - 1) != numElements) {
249+
breakpoints.add(numElements);
250+
}
251+
return breakpoints;
252+
}
253+
}

0 commit comments

Comments
 (0)