Skip to content

Commit 2734377

Browse files
committed
Python: Add API graph support for parameter annotations
Adds API graph support for observing that in `def foo(x : Bar): ...` the variable `x` is likely to be an instance of the type `Bar` inside this function. In particular, we add `getInstanceFromAnnotation` as a predicate on API graph nodes that tracks this step (corresponding to a new edge type labeled with "annotation" in the API graph), and extend the existing `getAnInstance` predicate to also include instances arising from type annotations. A more complete solution would also add support for annotated assignments (`x : Foo = ...` or just `x : Foo`) as well as track types through type aliases (`type Foo = Bar`). This turns out to be non-trivial, however, as these type constructs don't have any CFG nodes (and so no data-flow nodes by default either). So as not to let the perfect be the enemy of the good, this commit only targets the parameter annotation case (which is also likely to be the most common use case anyway). The tests for API graphs have been extended accordingly, including tests for the kinds of type ascriptions that we _don't_ currently model in API graphs (marked with `MISSING:` in the inline tests).
1 parent 047e974 commit 2734377

File tree

4 files changed

+68
-1
lines changed

4 files changed

+68
-1
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
category: feature
3+
---
4+
5+
- Added support for parameter annotations in API graphs. This means that in a function definition such as `def foo(x: Bar): ...`, you can now use the `getInstanceFromAnnotation()` method to step from `Bar` to `x`. In addition, the `getAnInstance()` method now also includes instances arising from parameter annotations.

python/ql/lib/semmle/python/ApiGraphs.qll

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ module API {
195195
*/
196196
Node getReturn() { result = this.getASuccessor(Label::return()) }
197197

198+
/**
199+
* Gets a node representing instances of the class represented by this node, as specified via
200+
* type annotations.
201+
*/
202+
Node getInstanceFromAnnotation() { result = this.getASuccessor(Label::annotation()) }
203+
198204
/**
199205
* Gets a node representing the `i`th parameter of the function represented by this node.
200206
*
@@ -229,7 +235,9 @@ module API {
229235
/**
230236
* Gets a node representing an instance of the class (or a transitive subclass of the class) represented by this node.
231237
*/
232-
Node getAnInstance() { result = this.getASubclass*().getReturn() }
238+
Node getAnInstance() {
239+
result in [this.getASubclass*().getReturn(), this.getASubclass*().getInstanceFromAnnotation()]
240+
}
233241

234242
/**
235243
* Gets a node representing the result from awaiting this node.
@@ -834,6 +842,10 @@ module API {
834842
lbl = Label::return() and
835843
ref = pred.getACall()
836844
or
845+
// Getting an instance via a type annotation
846+
lbl = Label::annotation() and
847+
ref = pred.getAnAnnotatedInstance()
848+
or
837849
// Awaiting a node that is a use of `base`
838850
lbl = Label::await() and
839851
ref = pred.getAnAwaited()
@@ -1079,6 +1091,7 @@ module API {
10791091
} or
10801092
MkLabelSelfParameter() or
10811093
MkLabelReturn() or
1094+
MkLabelAnnotation() or
10821095
MkLabelSubclass() or
10831096
MkLabelAwait() or
10841097
MkLabelSubscript() or
@@ -1148,6 +1161,11 @@ module API {
11481161
override string toString() { result = "getReturn()" }
11491162
}
11501163

1164+
/** A label for annotations. */
1165+
class LabelAnnotation extends ApiLabel, MkLabelAnnotation {
1166+
override string toString() { result = "getAnnotatedInstance()" }
1167+
}
1168+
11511169
/** A label that gets the subclass of a class. */
11521170
class LabelSubclass extends ApiLabel, MkLabelSubclass {
11531171
override string toString() { result = "getASubclass()" }
@@ -1207,6 +1225,9 @@ module API {
12071225
/** Gets the `return` edge label. */
12081226
LabelReturn return() { any() }
12091227

1228+
/** Gets the `annotation` edge label. */
1229+
LabelAnnotation annotation() { any() }
1230+
12101231
/** Gets the `subclass` edge label. */
12111232
LabelSubclass subclass() { any() }
12121233

python/ql/lib/semmle/python/dataflow/new/internal/LocalSources.qll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ class LocalSourceNode extends Node {
119119
*/
120120
CallCfgNode getACall() { Cached::call(this, result) }
121121

122+
/**
123+
* Gets a node (for example, an annotated parameter) whose type annotation is an expression that this node flows to.
124+
*/
125+
Node getAnAnnotatedInstance() { Cached::annotatedInstance(this, result) }
126+
122127
/**
123128
* Gets an awaited value from this node.
124129
*/
@@ -275,6 +280,17 @@ private module Cached {
275280
)
276281
}
277282

283+
cached
284+
predicate annotatedInstance(LocalSourceNode node, Node instance) {
285+
exists(ExprNode n | node.flowsTo(n) |
286+
instance.asCfgNode().getNode() =
287+
any(AnnAssign ann | ann.getAnnotation() = n.asExpr()).getTarget()
288+
or
289+
instance.asCfgNode().getNode() =
290+
any(Parameter p | p.getAnnotation() = n.asCfgNode().getNode())
291+
)
292+
}
293+
278294
/**
279295
* Holds if `node` flows to a value that, when awaited, results in `awaited`.
280296
*/
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from types import AssignmentAnnotation, ParameterAnnotation
2+
3+
def test_annotated_assignment():
4+
local_x : AssignmentAnnotation = create_x() #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation")
5+
local_x #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation").getAnnotatedInstance()
6+
7+
global_x : AssignmentAnnotation #$ use=moduleImport("types").getMember("AssignmentAnnotation")
8+
global_x #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation").getAnnotatedInstance()
9+
10+
def test_parameter_annotation(parameter_y: ParameterAnnotation): #$ use=moduleImport("types").getMember("ParameterAnnotation")
11+
parameter_y #$ use=moduleImport("types").getMember("ParameterAnnotation").getAnnotatedInstance()
12+
13+
type Alias = AssignmentAnnotation
14+
15+
global_z : Alias #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation")
16+
global_z #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation").getAnnotatedInstance()
17+
18+
def test_parameter_alias(parameter_z: Alias): #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation")
19+
parameter_z #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation").getAnnotatedInstance()
20+
21+
# local type aliases
22+
def test_local_type_alias():
23+
type LocalAlias = AssignmentAnnotation
24+
local_alias : LocalAlias = create_value() #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation")
25+
local_alias #$ MISSING: use=moduleImport("types").getMember("AssignmentAnnotation").getAnnotatedInstance()

0 commit comments

Comments
 (0)