Skip to content

Commit b85d1b4

Browse files
committed
[SYSTEMDS-3834] Dependency tasks enable via Log4j
This commit reenable dependency graph printing for parallel transform encode. Previously one would have to recompile the system with modified code, while now we can enable the task graph via logging configurations. Closes #2222
1 parent a06ea0f commit b85d1b4

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

src/main/java/org/apache/sysds/runtime/util/DependencyTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
import org.apache.sysds.runtime.DMLRuntimeException;
3131

3232
public class DependencyTask<E> implements Comparable<DependencyTask<?>>, Callable<E> {
33-
public static final boolean ENABLE_DEBUG_DATA = false; // explain task graph
3433
protected static final Log LOG = LogFactory.getLog(DependencyTask.class.getName());
34+
/** debugging dependency tasks only used if LOG.isDebugEnabled */
35+
public List<DependencyTask<?>> _dependencyTasks = null;
3536

3637
private final Callable<E> _task;
3738
protected final List<DependencyTask<?>> _dependantTasks;
38-
public List<DependencyTask<?>> _dependencyTasks = null; // only for debugging
3939
private CompletableFuture<Future<?>> _future;
4040
private int _rdy = 0;
4141
private Integer _priority = 0;

src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public List<Future<Future<?>>> submitAll(List<? extends Callable<?>> tasks,
9292
public List<Object> submitAllAndWait(List<DependencyTask<?>> dtasks)
9393
throws ExecutionException, InterruptedException {
9494
List<Object> res = new ArrayList<>();
95-
if(DependencyTask.ENABLE_DEBUG_DATA) {
95+
if(LOG.isDebugEnabled()) {
9696
if (dtasks != null && dtasks.size() > 0)
9797
explainTaskGraph(dtasks);
9898
}
@@ -172,7 +172,7 @@ public static List<DependencyTask<?>> createDependencyTasks(List<? extends Calla
172172
DependencyTask<?> t = ret.get(i);
173173
for(Callable<?> dep : deps) {
174174
DependencyTask<?> dt = map.get(dep);
175-
if(DependencyTask.ENABLE_DEBUG_DATA) {
175+
if(LOG.isDebugEnabled()) {
176176
t._dependencyTasks = t._dependencyTasks == null ? new ArrayList<>() : t._dependencyTasks;
177177
t._dependencyTasks.add(dt);
178178
}
@@ -226,10 +226,12 @@ public static void explainTaskGraph(List<DependencyTask<?>> tasks) {
226226
sbs[level].append(offsets[level]);
227227
sbs[level].append(entry.getKey().toString()+"\n");
228228
}
229-
System.out.println("EXPlAIN (TASK-GRAPH):");
229+
StringBuilder sb = new StringBuilder("\n");
230+
sb.append("EXPlAIN (TASK-GRAPH):");
230231
for (int i=0; i<sbs.length; i++) {
231-
System.out.println(sbs[i].toString());
232+
sb.append(sbs[i].toString());
232233
}
234+
LOG.debug(sb.toString());
233235

234236
}
235237
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.frame.transform;
21+
22+
import static org.junit.Assert.assertTrue;
23+
import static org.junit.Assert.fail;
24+
25+
import java.util.List;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.log4j.Level;
30+
import org.apache.log4j.Logger;
31+
import org.apache.log4j.spi.LoggingEvent;
32+
import org.apache.sysds.common.Types.ValueType;
33+
import org.apache.sysds.runtime.frame.data.FrameBlock;
34+
import org.apache.sysds.runtime.transform.encode.CompressedEncode;
35+
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
36+
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
37+
import org.apache.sysds.runtime.util.DependencyTask;
38+
import org.apache.sysds.runtime.util.DependencyThreadPool;
39+
import org.apache.sysds.test.LoggingUtils;
40+
import org.apache.sysds.test.LoggingUtils.TestAppender;
41+
import org.apache.sysds.test.TestUtils;
42+
import org.junit.Test;
43+
44+
public class TransformLogger {
45+
protected static final Log LOG = LogFactory.getLog(TransformLogger.class.getName());
46+
47+
private final FrameBlock data;
48+
49+
public TransformLogger() {
50+
try {
51+
52+
data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
53+
data.setSchema(new ValueType[] {ValueType.INT32});
54+
}
55+
catch(Exception e) {
56+
e.printStackTrace();
57+
fail(e.getMessage());
58+
throw e;
59+
}
60+
}
61+
62+
@Test
63+
public void testDummyCode() {
64+
test("{dummycode:[C1]}");
65+
}
66+
67+
public void test(String spec) {
68+
final TestAppender appender = LoggingUtils.overwrite();
69+
70+
try {
71+
Logger.getLogger(CompressedEncode.class).setLevel(Level.DEBUG);
72+
Logger.getLogger(DependencyThreadPool.class).setLevel(Level.DEBUG);
73+
Logger.getLogger(DependencyTask.class).setLevel(Level.DEBUG);
74+
75+
FrameBlock meta = null;
76+
MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(),
77+
data.getNumColumns(), meta);
78+
encoderNormal.encode(data, 10);
79+
80+
final List<LoggingEvent> log = LoggingUtils.reinsert(appender);
81+
82+
boolean containsMessage = false;
83+
for(LoggingEvent l : log) {
84+
containsMessage |= l.getMessage().toString().contains("EXPlAIN (TASK-GRAPH):");
85+
}
86+
87+
assertTrue(containsMessage);
88+
89+
}
90+
catch(Exception e) {
91+
e.printStackTrace();
92+
fail(e.getMessage());
93+
}
94+
finally {
95+
LoggingUtils.reinsert(appender);
96+
}
97+
98+
}
99+
100+
}

0 commit comments

Comments
 (0)