Skip to content

Commit 03980ff

Browse files
author
David Baker Effendi
authored
[ruby] Rework In Pattern Match Overtainting (#5227)
* [ruby] Rework In Pattern Match Overtainting This change reworks the means of extracting the variables from an array pattern match by using an index access corresponding to the location we want to extract. e.g. ```ruby case [type, location] in [:value, result] puts "#{result}" else puts "else" end ``` Accesses `result` by `<tmp-0>[1]` (where `<tmp-0> = [type, location]`) * Formatting
1 parent 457d8d8 commit 03980ff

File tree

4 files changed

+23
-32
lines changed

4 files changed

+23
-32
lines changed

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
1111
ForExpression,
1212
IfExpression,
1313
InClause,
14+
IndexAccess,
1415
MatchVariable,
1516
MemberCall,
1617
NextExpression,
@@ -23,6 +24,7 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
2324
SingleAssignment,
2425
SplattingRubyNode,
2526
StatementList,
27+
StaticLiteral,
2628
UnaryExpression,
2729
Unknown,
2830
UnlessExpression,
@@ -299,18 +301,22 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
299301
case x: ArrayPattern =>
300302
val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse(inClause.pattern)
301303
val body = inClause.body
302-
val variables = x.children.collect { case x: MatchVariable => x }
303304

304-
val conditionBody = if (variables.nonEmpty && expr.isDefined) {
305-
StatementList(variables.map { lhs =>
306-
SingleAssignment(lhs, "=", MatchVariable()(expr.get.span))(
307-
inClause.span
308-
.spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${expr.get.text})")
305+
val stmts = x.children.zipWithIndex.flatMap {
306+
case (lhs: MatchVariable, idx) if expr.isDefined =>
307+
val arrAccess = {
308+
val code = s"${expr.get.text}[$idx]"
309+
val base = expr.get.copy()(expr.get.span.spanStart(expr.get.text))
310+
val indices = StaticLiteral(idx.toString)(expr.get.span.spanStart(idx.toString)) :: Nil
311+
IndexAccess(base, indices)(lhs.span.spanStart(code))
312+
}
313+
val asgn = SingleAssignment(lhs, "=", arrAccess)(
314+
inClause.span.spanStart(s"${lhs.span.text} = ${expr.get.text}[$idx]")
309315
)
310-
} :+ body)(body.span)
311-
} else {
312-
body
313-
}
316+
Option(asgn)
317+
case _ => None
318+
} :+ body
319+
val conditionBody = StatementList(stmts)(body.span)
314320

315321
(condition, conditionBody)
316322
case x =>

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
6262
case node: ReturnExpression => astForReturnExpression(node)
6363
case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier)
6464
case node: ArrayPattern => astForArrayPattern(node)
65-
case node: MatchVariable => astForMatchVariable(node)
6665
case node: DummyNode => Ast(node.node)
6766
case node: Unknown => astForUnknown(node)
6867
case x =>
@@ -624,27 +623,14 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
624623
val callNode_ =
625624
callNode(node, code(node), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH)
626625
val childrenAst = node.children.map {
627-
case x: MatchVariable => astForExpression(SimpleIdentifier()(x.span))
628-
case x => astForExpression(x)
626+
case x: MatchVariable if scope.lookupVariable(x.text).isEmpty => handleVariableOccurrence(x.toSimpleIdentifier)
627+
case x: MatchVariable => astForExpression(x.toSimpleIdentifier)
628+
case x => astForExpression(x)
629629
}
630630

631631
callAst(callNode_, childrenAst)
632632
}
633633

634-
protected def astForMatchVariable(node: MatchVariable): Ast = {
635-
val nodeCode = shortenCode(s"${RubyOperators.arrayPatternMatch}(${node.span.text})")
636-
val callNode_ = callNode(
637-
node,
638-
nodeCode,
639-
RubyOperators.arrayPatternMatch,
640-
RubyOperators.arrayPatternMatch,
641-
DispatchTypes.STATIC_DISPATCH
642-
)
643-
val identAst = astForExpression(SimpleIdentifier()(node.span))
644-
645-
callAst(callNode_, identAst :: Nil)
646-
}
647-
648634
protected def astForMandatoryParameter(node: RubyExpression): Ast = handleVariableOccurrence(node)
649635

650636
protected def astForSimpleCall(node: SimpleCall): Ast = {

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ object Defines {
3939
val hashInitializer = "<operator>.hashInitializer"
4040
val association = "<operator>.association"
4141
val splat = "<operator>.splat"
42-
val arrayPatternMatch = "<operator>.arrayPatternMatch"
4342
val regexpMatch = "=~"
4443
val regexpNotMatch = "!~"
4544
}

joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,17 @@ class CaseTests extends RubyCode2CpgFixture {
201201
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
202202
lhs.name shouldBe "result"
203203

204-
rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch
205-
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(<tmp-0>)"
204+
rhs.methodFullName shouldBe Operators.indexAccess
205+
rhs.code shouldBe s"<tmp-0>[1]"
206206
case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]")
207207
}
208208

209209
inside(notResultMatchAssignment.argument.l) {
210210
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
211211
lhs.name shouldBe "notResult"
212212

213-
rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch
214-
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(<tmp-0>)"
213+
rhs.methodFullName shouldBe Operators.indexAccess
214+
rhs.code shouldBe s"<tmp-0>[1]"
215215
case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]")
216216
}
217217
case _ => fail(s"Expected two true branches")

0 commit comments

Comments
 (0)