Skip to content

Commit a8cca4b

Browse files
Merge pull request github#6373 from joefarebrother/test-gen-improvements
Java: Test generator improvements
2 parents 0049b8e + 309f0e7 commit a8cca4b

File tree

2 files changed

+156
-97
lines changed

2 files changed

+156
-97
lines changed

java/ql/src/utils/GenerateFlowTestCase.py

Lines changed: 121 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import tempfile
1313

1414
if any(s == "--help" for s in sys.argv):
15-
print("""Usage:
16-
GenerateFlowTestCase.py specsToTest.csv projectPom.xml outdir
15+
print("""Usage:
16+
GenerateFlowTestCase.py specsToTest.csv projectPom.xml outdir [--force]
1717
1818
This generates test cases exercising function model specifications found in specsToTest.csv
1919
producing files Test.java, test.ql and test.expected in outdir.
@@ -22,32 +22,41 @@
2222
Typically this means supplying a skeleton POM <dependencies> section that retrieves whatever jars
2323
contain the needed classes.
2424
25+
If --force is present, existing files may be overwritten.
26+
2527
Requirements: `mvn` and `codeql` should both appear on your path.
2628
2729
After test generation completes, any lines in specsToTest.csv that didn't produce tests are output.
2830
If this happens, check the spelling of class and method names, and the syntax of input and output specifications.
2931
""")
30-
sys.exit(0)
32+
sys.exit(0)
33+
34+
force = False
35+
if "--force" in sys.argv:
36+
sys.argv.remove("--force")
37+
force = True
3138

3239
if len(sys.argv) != 4:
33-
print("Usage: GenerateFlowTestCase.py specsToTest.csv projectPom.xml outdir", file=sys.stderr)
34-
print("specsToTest.csv should contain CSV rows describing method taint-propagation specifications to test", file=sys.stderr)
35-
print("projectPom.xml should import dependencies sufficient to resolve the types used in specsToTest.csv", file=sys.stderr)
36-
sys.exit(1)
40+
print(
41+
"Usage: GenerateFlowTestCase.py specsToTest.csv projectPom.xml outdir [--force]", file=sys.stderr)
42+
print("specsToTest.csv should contain CSV rows describing method taint-propagation specifications to test", file=sys.stderr)
43+
print("projectPom.xml should import dependencies sufficient to resolve the types used in specsToTest.csv", file=sys.stderr)
44+
sys.exit(1)
3745

3846
try:
39-
os.makedirs(sys.argv[3])
47+
os.makedirs(sys.argv[3])
4048
except Exception as e:
41-
if e.errno != errno.EEXIST:
42-
print("Failed to create output directory %s: %s" % (sys.argv[3], e))
43-
sys.exit(1)
49+
if e.errno != errno.EEXIST:
50+
print("Failed to create output directory %s: %s" % (sys.argv[3], e))
51+
sys.exit(1)
4452

4553
resultJava = os.path.join(sys.argv[3], "Test.java")
4654
resultQl = os.path.join(sys.argv[3], "test.ql")
4755

48-
if os.path.exists(resultJava) or os.path.exists(resultQl):
49-
print("Won't overwrite existing files '%s' or '%s'" % (resultJava, resultQl), file = sys.stderr)
50-
sys.exit(1)
56+
if not force and (os.path.exists(resultJava) or os.path.exists(resultQl)):
57+
print("Won't overwrite existing files '%s' or '%s'" %
58+
(resultJava, resultQl), file=sys.stderr)
59+
sys.exit(1)
5160

5261
workDir = tempfile.mkdtemp()
5362

@@ -57,129 +66,159 @@
5766
os.makedirs(projectDir)
5867

5968
try:
60-
shutil.copyfile(sys.argv[2], os.path.join(projectDir, "pom.xml"))
69+
shutil.copyfile(sys.argv[2], os.path.join(projectDir, "pom.xml"))
6170
except Exception as e:
62-
print("Failed to read project POM %s: %s" % (sys.argv[2], e), file = sys.stderr)
63-
sys.exit(1)
71+
print("Failed to read project POM %s: %s" %
72+
(sys.argv[2], e), file=sys.stderr)
73+
sys.exit(1)
6474

6575
commentRegex = re.compile("^\s*(//|#)")
76+
77+
6678
def isComment(s):
67-
return commentRegex.match(s) is not None
79+
return commentRegex.match(s) is not None
80+
6881

6982
try:
70-
with open(sys.argv[1], "r") as f:
71-
specs = [l for l in f if not isComment(l)]
83+
with open(sys.argv[1], "r") as f:
84+
specs = [l for l in f if not isComment(l)]
7285
except Exception as e:
73-
print("Failed to open %s: %s\n" % (sys.argv[1], e))
74-
sys.exit(1)
86+
print("Failed to open %s: %s\n" % (sys.argv[1], e))
87+
sys.exit(1)
7588

7689
projectTestPkgDir = os.path.join(projectDir, "src", "main", "java", "test")
7790
projectTestFile = os.path.join(projectTestPkgDir, "Test.java")
7891

7992
os.makedirs(projectTestPkgDir)
8093

94+
8195
def qualifiedOuterNameFromCsvRow(row):
82-
cells = row.split(";")
83-
if len(cells) < 2:
84-
return None
85-
return cells[0] + "." + cells[1].replace("$", ".")
96+
cells = row.split(";")
97+
if len(cells) < 2:
98+
return None
99+
return cells[0] + "." + cells[1].replace("$", ".")
100+
86101

87102
with open(projectTestFile, "w") as testJava:
88-
testJava.write("package test;\n\npublic class Test {\n\n")
103+
testJava.write("package test;\n\npublic class Test {\n\n")
89104

90-
for i, spec in enumerate(specs):
91-
outerName = qualifiedOuterNameFromCsvRow(spec)
92-
if outerName is None:
93-
print("A taint specification has the wrong format: should be 'package;classname;methodname....'", file = sys.stderr)
94-
print("Mis-formatted row: " + spec, file = sys.stderr)
95-
sys.exit(1)
96-
testJava.write("\t%s obj%d = null;\n" % (outerName, i))
105+
for i, spec in enumerate(specs):
106+
outerName = qualifiedOuterNameFromCsvRow(spec)
107+
if outerName is None:
108+
print("A taint specification has the wrong format: should be 'package;classname;methodname....'", file=sys.stderr)
109+
print("Mis-formatted row: " + spec, file=sys.stderr)
110+
sys.exit(1)
111+
testJava.write("\t%s obj%d = null;\n" % (outerName, i))
97112

98-
testJava.write("}")
113+
testJava.write("}")
99114

100115
print("Creating project database")
101116
cmd = ["codeql", "database", "create", "--language=java", "db"]
102-
ret = subprocess.call(cmd, cwd = projectDir)
117+
ret = subprocess.call(cmd, cwd=projectDir)
103118
if ret != 0:
104-
print("Failed to create project database. Check that '%s' is a valid POM that pulls in all necessary dependencies, and '%s' specifies valid classes and methods." % (sys.argv[2], sys.argv[1]), file = sys.stderr)
105-
print("Failed command was: %s (cwd: %s)" % (shlex.join(cmd), projectDir), file = sys.stderr)
106-
sys.exit(1)
119+
print("Failed to create project database. Check that '%s' is a valid POM that pulls in all necessary dependencies, and '%s' specifies valid classes and methods." % (
120+
sys.argv[2], sys.argv[1]), file=sys.stderr)
121+
print("Failed command was: %s (cwd: %s)" %
122+
(shlex.join(cmd), projectDir), file=sys.stderr)
123+
sys.exit(1)
107124

108125
print("Creating test-generation query")
109126
queryDir = os.path.join(workDir, "query")
110127
os.makedirs(queryDir)
111128
qlFile = os.path.join(queryDir, "gen.ql")
112129
with open(os.path.join(queryDir, "qlpack.yml"), "w") as f:
113-
f.write("name: test-generation-query\nversion: 0.0.0\nlibraryPathDependencies: codeql-java")
130+
f.write("name: test-generation-query\nversion: 0.0.0\nlibraryPathDependencies: codeql-java")
114131
with open(qlFile, "w") as f:
115-
f.write("import java\nimport utils.GenerateFlowTestCase\n\nclass GenRow extends TargetSummaryModelCsv {\n\n\toverride predicate row(string r) {\n\t\tr = [\n")
116-
f.write(",\n".join('\t\t\t"%s"' % spec.strip() for spec in specs))
117-
f.write("\n\t\t]\n\t}\n}\n")
132+
f.write(
133+
"import java\nimport utils.GenerateFlowTestCase\n\nclass GenRow extends TargetSummaryModelCsv {\n\n\toverride predicate row(string r) {\n\t\tr = [\n")
134+
f.write(",\n".join('\t\t\t"%s"' % spec.strip() for spec in specs))
135+
f.write("\n\t\t]\n\t}\n}\n")
118136

119137
print("Generating tests")
120138
generatedBqrs = os.path.join(queryDir, "out.bqrs")
121-
cmd = ['codeql', 'query', 'run', qlFile, '--database', os.path.join(projectDir, "db"), '--output', generatedBqrs]
139+
cmd = ['codeql', 'query', 'run', qlFile, '--database',
140+
os.path.join(projectDir, "db"), '--output', generatedBqrs]
122141
ret = subprocess.call(cmd)
123142
if ret != 0:
124-
print("Failed to generate tests. Failed command was: " + shlex.join(cmd))
125-
sys.exit(1)
143+
print("Failed to generate tests. Failed command was: " + shlex.join(cmd))
144+
sys.exit(1)
126145

127146
generatedJson = os.path.join(queryDir, "out.json")
128-
cmd = ['codeql', 'bqrs', 'decode', generatedBqrs, '--format=json', '--output', generatedJson]
147+
cmd = ['codeql', 'bqrs', 'decode', generatedBqrs,
148+
'--format=json', '--output', generatedJson]
129149
ret = subprocess.call(cmd)
130150
if ret != 0:
131-
print("Failed to decode BQRS. Failed command was: " + shlex.join(cmd))
132-
sys.exit(1)
133-
134-
def getTuples(queryName, jsonResult, fname):
135-
if queryName not in jsonResult or "tuples" not in jsonResult[queryName]:
136-
print("Failed to read generated tests: expected key '%s' with a 'tuples' subkey in file '%s'" % (queryName, fname), file = sys.stderr)
151+
print("Failed to decode BQRS. Failed command was: " + shlex.join(cmd))
137152
sys.exit(1)
138-
return jsonResult[queryName]["tuples"]
139153

140-
with open(generatedJson, "r") as f:
141-
generateOutput = json.load(f)
142-
expectedTables = ("getTestCase", "getASupportMethodModel", "missingSummaryModelCsv", "getAParseFailure")
143154

144-
testCaseRows, supportModelRows, missingSummaryModelCsvRows, parseFailureRows = \
145-
tuple([getTuples(k, generateOutput, generatedJson) for k in expectedTables])
155+
def getTuples(queryName, jsonResult, fname):
156+
if queryName not in jsonResult or "tuples" not in jsonResult[queryName]:
157+
print("Failed to read generated tests: expected key '%s' with a 'tuples' subkey in file '%s'" % (
158+
queryName, fname), file=sys.stderr)
159+
sys.exit(1)
160+
return jsonResult[queryName]["tuples"]
146161

147-
if len(testCaseRows) != 1 or len(testCaseRows[0]) != 1:
148-
print("Expected exactly one getTestCase result with one column (got: %s)" % json.dumps(testCaseRows), file = sys.stderr)
149-
if any(len(row) != 1 for row in supportModelRows):
150-
print("Expected exactly one column in getASupportMethodModel relation (got: %s)" % json.dumps(supportModelRows), file = sys.stderr)
151-
if any(len(row) != 2 for row in parseFailureRows):
152-
print("Expected exactly two columns in parseFailureRows relation (got: %s)" % json.dumps(parseFailureRows), file = sys.stderr)
153162

154-
if len(missingSummaryModelCsvRows) != 0:
155-
print("Tests for some CSV rows were requested that were not in scope (SummaryModelCsv.row does not hold):\n" + "\n".join(r[0] for r in missingSummaryModelCsvRows))
156-
sys.exit(1)
157-
if len(parseFailureRows) != 0:
158-
print("The following rows failed to generate any test case. Check package, class and method name spelling, and argument and result specifications:\n%s" % "\n".join(r[0] + ": " + r[1] for r in parseFailureRows), file = sys.stderr)
159-
sys.exit(1)
163+
with open(generatedJson, "r") as f:
164+
generateOutput = json.load(f)
165+
expectedTables = ("getTestCase", "getASupportMethodModel",
166+
"missingSummaryModelCsv", "getAParseFailure", "noTestCaseGenerated")
167+
168+
testCaseRows, supportModelRows, missingSummaryModelCsvRows, parseFailureRows, noTestCaseGeneratedRows = \
169+
tuple([getTuples(k, generateOutput, generatedJson)
170+
for k in expectedTables])
171+
172+
if len(testCaseRows) != 1 or len(testCaseRows[0]) != 1:
173+
print("Expected exactly one getTestCase result with one column (got: %s)" %
174+
json.dumps(testCaseRows), file=sys.stderr)
175+
if any(len(row) != 1 for row in supportModelRows):
176+
print("Expected exactly one column in getASupportMethodModel relation (got: %s)" %
177+
json.dumps(supportModelRows), file=sys.stderr)
178+
if any(len(row) != 2 for row in parseFailureRows):
179+
print("Expected exactly two columns in parseFailureRows relation (got: %s)" %
180+
json.dumps(parseFailureRows), file=sys.stderr)
181+
if any(len(row) != 1 for row in noTestCaseGeneratedRows):
182+
print("Expected exactly one column in noTestCaseGenerated relation (got: %s)" %
183+
json.dumps(noTestCaseGeneratedRows), file=sys.stderr)
184+
185+
if len(missingSummaryModelCsvRows) != 0:
186+
print("Tests for some CSV rows were requested that were not in scope (SummaryModelCsv.row does not hold):\n" +
187+
"\n".join(r[0] for r in missingSummaryModelCsvRows))
188+
sys.exit(1)
189+
if len(parseFailureRows) != 0:
190+
print("The following rows failed to generate any test case. Check package, class and method name spelling, and argument and result specifications:\n%s" %
191+
"\n".join(r[0] + ": " + r[1] for r in parseFailureRows), file=sys.stderr)
192+
sys.exit(1)
193+
if len(noTestCaseGeneratedRows) != 0:
194+
print("The following CSV rows failed to generate any test case due to a limitation of the query. Other test cases will still be generated:\n" +
195+
"\n".join(r[0] for r in noTestCaseGeneratedRows))
160196

161197
with open(resultJava, "w") as f:
162-
f.write(generateOutput["getTestCase"]["tuples"][0][0])
198+
f.write(generateOutput["getTestCase"]["tuples"][0][0])
163199

164200
scriptPath = os.path.dirname(sys.argv[0])
165201

202+
166203
def copyfile(fromName, toFileHandle):
167-
with open(os.path.join(scriptPath, fromName), "r") as fromFileHandle:
168-
shutil.copyfileobj(fromFileHandle, toFileHandle)
204+
with open(os.path.join(scriptPath, fromName), "r") as fromFileHandle:
205+
shutil.copyfileobj(fromFileHandle, toFileHandle)
206+
169207

170208
with open(resultQl, "w") as f:
171-
copyfile("testHeader.qlfrag", f)
172-
if len(supportModelRows) != 0:
173-
copyfile("testModelsHeader.qlfrag", f)
174-
f.write(", ".join('"%s"' % modelSpecRow[0].strip() for modelSpecRow in supportModelRows))
175-
copyfile("testModelsFooter.qlfrag", f)
176-
copyfile("testFooter.qlfrag", f)
209+
copyfile("testHeader.qlfrag", f)
210+
if len(supportModelRows) != 0:
211+
copyfile("testModelsHeader.qlfrag", f)
212+
f.write(", ".join('"%s"' %
213+
modelSpecRow[0].strip() for modelSpecRow in supportModelRows))
214+
copyfile("testModelsFooter.qlfrag", f)
215+
copyfile("testFooter.qlfrag", f)
177216

178217
# Make an empty .expected file, since this is an inline-exectations test
179218
with open(os.path.join(sys.argv[3], "test.expected"), "w"):
180-
pass
219+
pass
181220

182221
cmd = ['codeql', 'query', 'format', '-qq', '-i', resultQl]
183222
subprocess.call(cmd)
184223

185-
shutil.rmtree(workDir)
224+
shutil.rmtree(workDir)

java/ql/src/utils/GenerateFlowTestCase.qll

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,26 @@ query string getAParseFailure(string reason) {
6262
)
6363
}
6464

65+
/**
66+
* Gets a CSV row for which a test was requested and was correctly parsed,
67+
* but for which no test case could be generated due to a limitation of the query.
68+
*/
69+
query string noTestCaseGenerated() {
70+
any(TargetSummaryModelCsv target).row(result) and
71+
any(SummaryModelCsv model).row(result) and
72+
not exists(getAParseFailure(_)) and
73+
not exists(any(TestCase tc).getATestSnippetForRow(result))
74+
}
75+
6576
private class CallableToTest extends Callable {
6677
CallableToTest() {
6778
exists(
6879
string namespace, string type, boolean subtypes, string name, string signature, string ext
6980
|
7081
summaryModel(namespace, type, subtypes, name, signature, ext, _, _, _) and
71-
this = interpretElement(namespace, type, subtypes, name, signature, ext)
82+
this = interpretElement(namespace, type, subtypes, name, signature, ext) and
83+
this.isPublic() and
84+
getRootType(this.getDeclaringType()).isPublic()
7285
)
7386
}
7487
}
@@ -148,7 +161,10 @@ string contentToken(Content c) {
148161
RefType getRootType(RefType t) {
149162
if t instanceof NestedType
150163
then result = getRootType(t.(NestedType).getEnclosingType())
151-
else result = t
164+
else
165+
if t instanceof Array
166+
then result = getRootType(t.(Array).getElementType())
167+
else result = t
152168
}
153169

154170
/**
@@ -485,18 +501,22 @@ predicate isImportable(Type t) {
485501
* if we cannot import it due to a name clash.
486502
*/
487503
string getShortNameIfPossible(Type t) {
488-
getRootSourceDeclaration(t) = any(TestCase tc).getADesiredImport() and
489-
if t instanceof RefType
490-
then
491-
exists(RefType replaced, string nestedName |
492-
replaced = replaceTypeVariable(t).getSourceDeclaration() and
493-
nestedName = replaced.nestedName().replaceAll("$", ".")
494-
|
495-
if isImportable(getRootSourceDeclaration(t))
496-
then result = nestedName
497-
else result = replaced.getPackage().getName() + "." + nestedName
498-
)
499-
else result = t.getName()
504+
if t instanceof Array
505+
then result = getShortNameIfPossible(t.(Array).getElementType()) + "[]"
506+
else (
507+
getRootSourceDeclaration(t) = any(TestCase tc).getADesiredImport() and
508+
if t instanceof RefType
509+
then
510+
exists(RefType replaced, string nestedName |
511+
replaced = replaceTypeVariable(t).getSourceDeclaration() and
512+
nestedName = replaced.nestedName().replaceAll("$", ".")
513+
|
514+
if isImportable(getRootSourceDeclaration(t))
515+
then result = nestedName
516+
else result = replaced.getPackage().getName() + "." + nestedName
517+
)
518+
else result = t.getName()
519+
)
500520
}
501521

502522
/**
@@ -533,7 +553,7 @@ query string getTestCase() {
533553
result =
534554
"package generatedtest;\n\n" + concat(getAnImportStatement() + "\n") +
535555
"\n// Test case generated by GenerateFlowTestCase.ql\npublic class Test {\n\n" +
536-
concat("\t" + getASupportMethod() + "\n") + "\n\tpublic void test() {\n\n" +
556+
concat("\t" + getASupportMethod() + "\n") + "\n\tpublic void test() throws Exception {\n\n" +
537557
concat(string row, string snippet |
538558
snippet = any(TestCase tc).getATestSnippetForRow(row)
539559
|

0 commit comments

Comments
 (0)