Skip to content

Commit cdff385

Browse files
aeupmboehm7
authored andcommitted
[SYSTEMDS-3704] New resource-aware operator scheduling
Closes #2197.
1 parent 04b0d09 commit cdff385

File tree

6 files changed

+588
-5
lines changed

6 files changed

+588
-5
lines changed

src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,17 @@
2626

2727
public class IDagLinearizerFactory {
2828
public static Log LOG = LogFactory.getLog(IDagLinearizerFactory.class.getName());
29-
29+
3030
public enum DagLinearizer {
31-
DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO, PIPELINE_DEPTH_FIRST;
31+
DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO,
32+
PIPELINE_DEPTH_FIRST, RESOURCE_AWARE_FAST, RESOURCE_AWARE_OPTIMAL;
3233
}
3334

3435
public static IDagLinearizer createDagLinearizer() {
3536
DagLinearizer type = ConfigurationManager.getLinearizationOrder();
3637
return createDagLinearizer(type);
3738
}
38-
39+
3940
public static IDagLinearizer createDagLinearizer(DagLinearizer type) {
4041
switch(type) {
4142
case AUTO:
@@ -50,8 +51,12 @@ public static IDagLinearizer createDagLinearizer(DagLinearizer type) {
5051
return new LinearizerMinIntermediates();
5152
case PIPELINE_DEPTH_FIRST:
5253
return new LinearizerPipelineAware();
54+
case RESOURCE_AWARE_FAST:
55+
return new LinearizerResourceAwareFast();
56+
case RESOURCE_AWARE_OPTIMAL:
57+
return new LinearizerResourceAwareOptimal();
5358
default:
54-
LOG.warn("Invalid DAG_LINEARIZATION: "+type+", falling back to DEPTH_FIRST ordering");
59+
LOG.warn("Invalid DAG_LINEARIZATION: " + type + ", falling back to DEPTH_FIRST ordering");
5560
return new LinearizerDepthFirst();
5661
}
5762
}
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.lops.compile.linearization;
21+
22+
import org.apache.sysds.lops.Lop;
23+
24+
import java.util.ArrayList;
25+
import java.util.Collections;
26+
import java.util.Comparator;
27+
import java.util.HashSet;
28+
import java.util.Iterator;
29+
import java.util.List;
30+
import java.util.Set;
31+
import java.util.stream.Collectors;
32+
import java.util.stream.IntStream;
33+
import java.util.stream.Stream;
34+
35+
public class LinearizerResourceAwareFast extends IDagLinearizer {
36+
37+
static class Dependency {
38+
int nodeIndex;
39+
int sequenceIndex;
40+
List<Integer> dependencies;
41+
42+
Dependency(int sequenceIndex, int nodeIndex, List<Integer> dependencies) {
43+
this.sequenceIndex = sequenceIndex;
44+
this.nodeIndex = nodeIndex;
45+
this.dependencies = dependencies;
46+
}
47+
48+
public int getSequenceIndex() {
49+
return sequenceIndex;
50+
}
51+
52+
public int getNodeIndex() {
53+
return nodeIndex;
54+
}
55+
56+
public List<Integer> getDependencies() {
57+
return dependencies;
58+
}
59+
}
60+
61+
static class Item {
62+
List<Integer> steps;
63+
List<Integer> current;
64+
Set<Intermediate> intermediates;
65+
double maxMemoryUsage;
66+
67+
Item(List<Integer> steps, List<Integer> current, Set<Intermediate> intermediates, double maxMemoryUsage) {
68+
this.steps = steps;
69+
this.current = current;
70+
this.intermediates = intermediates;
71+
this.maxMemoryUsage = maxMemoryUsage;
72+
}
73+
74+
public List<Integer> getSteps() {
75+
return steps;
76+
}
77+
78+
public List<Integer> getCurrent() {
79+
return current;
80+
}
81+
82+
public double getMaxMemoryUsage() {
83+
return maxMemoryUsage;
84+
}
85+
86+
public Set<Intermediate> getIntermediates() {
87+
return intermediates;
88+
}
89+
}
90+
91+
static class Intermediate {
92+
List<Long> lopIDs;
93+
double memoryUsage;
94+
95+
Intermediate(List<Long> lopIDs, double memoryUsage) {
96+
this.lopIDs = lopIDs;
97+
this.memoryUsage = memoryUsage;
98+
}
99+
100+
void remove(long ID) {
101+
lopIDs.remove(ID);
102+
}
103+
104+
public List<Long> getLopIDs() {
105+
return lopIDs;
106+
}
107+
108+
public double getMemoryUsage() {
109+
return memoryUsage;
110+
}
111+
}
112+
113+
List<Lop> remaining;
114+
115+
@Override
116+
public List<Lop> linearize(List<Lop> dag) {
117+
List<List<Lop>> sequences = new ArrayList<>();
118+
remaining = new ArrayList<>(dag);
119+
120+
List<Lop> outputNodes = remaining.stream().filter(node -> node.getOutputs().isEmpty())
121+
.collect(Collectors.toList());
122+
123+
for(Lop outputNode : outputNodes) {
124+
sequences.add(findSequence(outputNode));
125+
}
126+
127+
while(!remaining.isEmpty()) {
128+
int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().getAsInt();
129+
Lop node = remaining.stream().filter(n -> n.getLevel() == maxLevel).findFirst().orElseThrow();
130+
sequences.add(findSequence(node));
131+
}
132+
133+
return scheduleSequences(sequences);
134+
}
135+
136+
List<Lop> scheduleSequences(List<List<Lop>> sequences) {
137+
Set<List<Integer>> visited = new HashSet<>();
138+
List<Item> scheduledItems = new ArrayList<>();
139+
140+
Set<Dependency> dependencies = getDependencies(sequences);
141+
List<Integer> sequencesMaxIndex = sequences.stream().map(entry -> entry.size() - 1)
142+
.collect(Collectors.toList());
143+
144+
Item currentItem = new Item(new ArrayList<>(), Collections.nCopies(sequences.size(), -1), new HashSet<>(), 0.0);
145+
146+
while(!currentItem.getCurrent().equals(sequencesMaxIndex)) {
147+
148+
for(int i = 0; i < sequences.size(); i++) {
149+
150+
List<Lop> sequence = sequences.get(i);
151+
152+
if(currentItem.getCurrent().get(i) + 1 < sequence.size()) {
153+
List<Integer> newCurrent = new ArrayList<>(currentItem.getCurrent());
154+
newCurrent.set(i, newCurrent.get(i) + 1);
155+
156+
if(!visited.contains(newCurrent)) {
157+
Set<Dependency> filteredDependencies = dependencies.stream()
158+
.filter(entry -> entry.getNodeIndex() == newCurrent.get(entry.getSequenceIndex()))
159+
.collect(Collectors.toSet());
160+
161+
boolean dependencyIssue = filteredDependencies.parallelStream().anyMatch(
162+
dependency -> IntStream.range(0, newCurrent.size()).anyMatch(
163+
j -> j != dependency.getSequenceIndex() &&
164+
newCurrent.get(j) < dependency.getDependencies().get(j)));
165+
166+
if(!dependencyIssue) {
167+
Set<Intermediate> newIntermediates = new HashSet<>(currentItem.getIntermediates());
168+
169+
Lop nextLop = sequence.get(newCurrent.get(i));
170+
171+
Iterator<Intermediate> intermediateIter = newIntermediates.iterator();
172+
173+
while(intermediateIter.hasNext()) {
174+
Intermediate entry = intermediateIter.next();
175+
entry.remove(nextLop.getID());
176+
if(entry.getLopIDs().isEmpty())
177+
intermediateIter.remove();
178+
}
179+
180+
newIntermediates.add(new Intermediate(
181+
nextLop.getOutputs().stream().map(Lop::getID).collect(Collectors.toList()),
182+
nextLop.getOutputMemoryEstimate()));
183+
184+
List<Integer> newSteps = new ArrayList<>(currentItem.getSteps());
185+
newSteps.add(i);
186+
187+
double mem = newIntermediates.stream().map(Intermediate::getMemoryUsage)
188+
.reduce((double) 0, Double::sum);
189+
190+
Item newItem = new Item(newSteps, newCurrent, newIntermediates,
191+
Math.max(mem, currentItem.getMaxMemoryUsage()));
192+
193+
int index = Collections.binarySearch(scheduledItems, newItem,
194+
Comparator.comparing(Item::getMaxMemoryUsage));
195+
196+
if(index < 0) {
197+
index = -index - 1;
198+
}
199+
200+
scheduledItems.add(index, newItem);
201+
}
202+
visited.add(newCurrent);
203+
}
204+
}
205+
}
206+
207+
currentItem = scheduledItems.remove(0);
208+
}
209+
210+
return walkPath(sequences, currentItem.getSteps());
211+
}
212+
213+
List<Lop> walkPath(List<List<Lop>> sequences, List<Integer> path) {
214+
Iterator<Integer> iterator = path.iterator();
215+
List<Lop> sequence = new ArrayList<>();
216+
217+
while(iterator.hasNext()) {
218+
sequence.add(sequences.get(iterator.next()).remove(0));
219+
}
220+
221+
return sequence;
222+
}
223+
224+
List<Lop> findSequence(Lop startNode) {
225+
List<Lop> sequence = new ArrayList<>();
226+
Lop currentNode = startNode;
227+
sequence.add(currentNode);
228+
remaining.remove(currentNode);
229+
230+
while(currentNode.getInputs().size() == 1) {
231+
if(remaining.contains(currentNode.getInput(0))) {
232+
currentNode = currentNode.getInput(0);
233+
sequence.add(currentNode);
234+
remaining.remove(currentNode);
235+
}
236+
else {
237+
Collections.reverse(sequence);
238+
return sequence;
239+
}
240+
}
241+
242+
Collections.reverse(sequence);
243+
244+
List<Lop> children = currentNode.getInputs();
245+
246+
if(children.isEmpty()) {
247+
return sequence;
248+
}
249+
250+
List<List<Lop>> childSequences = new ArrayList<>();
251+
252+
for(Lop child : children) {
253+
if(remaining.contains(child)) {
254+
childSequences.add(findSequence(child));
255+
}
256+
}
257+
258+
List<Lop> finalSequence = scheduleSequences(childSequences);
259+
260+
return Stream.concat(finalSequence.stream(), sequence.stream()).collect(Collectors.toList());
261+
}
262+
263+
Set<Dependency> getDependencies(List<List<Lop>> sequences) {
264+
Set<Dependency> dependencies = new HashSet<>();
265+
266+
// Get IDs of each Lop in each sequence for faster lookup
267+
List<List<Long>> sequencesLopIDs = sequences.stream()
268+
.map(sequence -> sequence.stream().map(Lop::getID).collect(Collectors.toList()))
269+
.collect(Collectors.toList());
270+
271+
int lastSequenceWithOutput = -1;
272+
273+
// Go through each sequence and check for dependencies
274+
for(int j = 0; j < sequences.size(); j++) {
275+
List<Lop> sequence = sequences.get(j);
276+
int sequenceSize = sequence.size();
277+
int sequenceIndex = j;
278+
279+
// Check if the current sequence depends on other sequences
280+
sequence.get(0).getInputs().forEach(input -> {
281+
long inputID = input.getID();
282+
List<Integer> dependencyIndices = sequencesLopIDs.stream()
283+
.map(list -> list.contains(inputID) ? list.indexOf(inputID) : -1).collect(Collectors.toList());
284+
285+
dependencies.add(new Dependency(sequenceIndex, 0, dependencyIndices));
286+
});
287+
288+
// Check for Lops that depends on Lops from other sequences
289+
for(int k = 0; k < sequenceSize; k++) {
290+
int finalK = k;
291+
int finalJ = j;
292+
sequence.get(k).getInputs().forEach(input -> {
293+
long inputID = input.getID();
294+
if(!sequencesLopIDs.get(finalJ).contains(inputID)) {
295+
List<Integer> dependencyIndices = sequencesLopIDs.stream()
296+
.map(list -> list.contains(inputID) ? list.indexOf(inputID) : -1)
297+
.collect(Collectors.toList());
298+
299+
dependencies.add(new Dependency(finalJ, finalK, dependencyIndices));
300+
}
301+
});
302+
}
303+
304+
// Dependency chain between output Lops so that the outputs are in the correct order
305+
if(sequence.get(sequenceSize - 1).getOutputs().isEmpty()) {
306+
if(lastSequenceWithOutput != -1) {
307+
List<Integer> dependencyList = new ArrayList<>(Collections.nCopies(sequences.size(), -1));
308+
dependencyList.set(lastSequenceWithOutput,
309+
sequences.get(lastSequenceWithOutput).size() - 1);
310+
dependencies.add(new Dependency(j, sequenceSize - 1, dependencyList));
311+
}
312+
lastSequenceWithOutput = j;
313+
}
314+
}
315+
316+
return dependencies;
317+
}
318+
}

0 commit comments

Comments
 (0)