Skip to content

Commit 398b2a3

Browse files
committed
Swift: Add more test variants.
1 parent 5496b11 commit 398b2a3

File tree

2 files changed

+151
-28
lines changed

2 files changed

+151
-28
lines changed
Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,58 @@
11

22
// --- stubs ---
33

4+
struct URL
5+
{
6+
init?(string: String) {}
7+
init?(string: String, relativeTo: URL?) {}
8+
}
9+
10+
extension String {
11+
init(contentsOf: URL) throws {
12+
var data = ""
13+
14+
// ...
15+
16+
self.init(data)
17+
}
18+
}
19+
420
public protocol Binding {}
521

622
extension String: Binding {}
723

824
class Statement {
9-
init(_ connection: Connection, _ SQL: String) throws {}
25+
fileprivate let connection: Connection
1026

11-
public func bind(_ values: Binding?...) -> Statement { return Statement() }
12-
public func bind(_ values: [Binding?]) -> Statement { return Statement() }
13-
public func bind(_ values: [String: Binding?]) -> Statement { return Statement() }
27+
init(_ connection: Connection, _ SQL: String) throws { self.connection = connection}
1428

15-
@discardableResult public func run(_ bindings: Binding?...) throws -> Statement { return Statement() }
16-
@discardableResult public func run(_ bindings: [Binding?]) throws -> Statement { return Statement() }
17-
@discardableResult public func run(_ bindings: [String: Binding?]) throws -> Statement { return Statement() }
29+
public func bind(_ values: Binding?...) -> Statement { return Statement(connection, "") }
30+
public func bind(_ values: [Binding?]) -> Statement { return Statement(connection, "") }
31+
public func bind(_ values: [String: Binding?]) -> Statement { return Statement(connection, "") }
1832

19-
public func scalar(_ bindings: Binding?...) throws -> Binding? { return Binding() }
20-
public func scalar(_ bindings: [Binding?]) throws -> Binding? { return Binding() }
21-
public func scalar(_ bindings: [String: Binding?]) throws -> Binding? { return Binding() }
33+
@discardableResult public func run(_ bindings: Binding?...) throws -> Statement { return Statement(connection, "") }
34+
@discardableResult public func run(_ bindings: [Binding?]) throws -> Statement { return Statement(connection, "") }
35+
@discardableResult public func run(_ bindings: [String: Binding?]) throws -> Statement { return Statement(connection, "") }
36+
37+
public func scalar(_ bindings: Binding?...) throws -> Binding? { return nil }
38+
public func scalar(_ bindings: [Binding?]) throws -> Binding? { return nil }
39+
public func scalar(_ bindings: [String: Binding?]) throws -> Binding? { return nil }
2240
}
2341

2442
class Connection {
2543
public func execute(_ SQL: String) throws { }
2644

27-
public func prepare(_ statement: String, _ bindings: Binding?...) throws -> Statement { return Statement() }
28-
public func prepare(_ statement: String, _ bindings: [Binding?]) throws -> Statement { return Statement() }
29-
public func prepare(_ statement: String, _ bindings: [String: Binding?]) throws -> Statement { return Statement() }
45+
public func prepare(_ statement: String, _ bindings: Binding?...) throws -> Statement { return Statement(self, "") }
46+
public func prepare(_ statement: String, _ bindings: [Binding?]) throws -> Statement { return Statement(self, "") }
47+
public func prepare(_ statement: String, _ bindings: [String: Binding?]) throws -> Statement { return Statement(self, "") }
3048

31-
@discardableResult public func run(_ statement: String, _ bindings: Binding?...) throws -> Statement { return Statement() }
32-
@discardableResult public func run(_ statement: String, _ bindings: [Binding?]) throws -> Statement { return Statement() }
33-
@discardableResult public func run(_ statement: String, _ bindings: [String: Binding?]) throws -> Statement { return Statement() }
49+
@discardableResult public func run(_ statement: String, _ bindings: Binding?...) throws -> Statement { return Statement(self, "") }
50+
@discardableResult public func run(_ statement: String, _ bindings: [Binding?]) throws -> Statement { return Statement(self, "") }
51+
@discardableResult public func run(_ statement: String, _ bindings: [String: Binding?]) throws -> Statement { return Statement(self, "") }
3452

35-
public func scalar(_ statement: String, _ bindings: Binding?...) throws -> Binding? { return Binding() }
36-
public func scalar(_ statement: String, _ bindings: [Binding?]) throws -> Binding? { return Binding() }
37-
public func scalar(_ statement: String, _ bindings: [String: Binding?]) throws -> Binding? { return Binding() }
53+
public func scalar(_ statement: String, _ bindings: Binding?...) throws -> Binding? { return nil }
54+
public func scalar(_ statement: String, _ bindings: [Binding?]) throws -> Binding? { return nil }
55+
public func scalar(_ statement: String, _ bindings: [String: Binding?]) throws -> Binding? { return nil }
3856
}
3957

4058
// --- tests ---
@@ -49,7 +67,6 @@ func test_sqlite_swift_api(db: Connection) {
4967
let unsafeQuery3 = "SELECT * FROM users WHERE username='\(remoteString)'"
5068
let safeQuery1 = "SELECT * FROM users WHERE username='\(localString)'"
5169
let safeQuery2 = "SELECT * FROM users WHERE username='\(remoteNumber)'"
52-
let varQuery = "SELECT * FROM users WHERE username=?"
5370

5471
// --- execute ---
5572

@@ -61,6 +78,8 @@ func test_sqlite_swift_api(db: Connection) {
6178

6279
// --- prepared statements ---
6380

81+
let varQuery = "SELECT * FROM users WHERE username=?"
82+
6483
let stmt1 = try db.prepare(unsafeQuery3) // BAD
6584
try stmt1.run()
6685

@@ -70,5 +89,45 @@ func test_sqlite_swift_api(db: Connection) {
7089
let stmt3 = try db.prepare(varQuery, remoteString) // GOOD
7190
try stmt3.run()
7291

73-
// TODO: test all versions of prepare, run, scalar on Connection and Statement
92+
let stmt4 = Statement(db, localString) // GOOD
93+
stmt4.run()
94+
95+
let stmt5 = Statement(db, remoteString) // BAD
96+
stmt5.run()
97+
98+
// --- more variants ---
99+
100+
let stmt6 = try db.prepare(unsafeQuery1, "") // BAD
101+
try stmt6.run()
102+
103+
let stmt7 = try db.prepare(unsafeQuery1, [""]) // BAD
104+
try stmt7.run()
105+
106+
let stmt8 = try db.prepare(unsafeQuery1, ["username": ""]) // BAD
107+
try stmt8.run()
108+
109+
db.run(unsafeQuery1, "") // BAD
110+
111+
db.run(unsafeQuery1, [""]) // BAD
112+
113+
db.run(unsafeQuery1, ["username": ""]) // BAD
114+
115+
db.scalar(unsafeQuery1, "") // BAD
116+
117+
db.scalar(unsafeQuery1, [""]) // BAD
118+
119+
db.scalar(unsafeQuery1, ["username": ""]) // BAD
120+
121+
let stmt9 = try db.prepare(varQuery) // GOOD
122+
stmt9.bind(remoteString)
123+
stmt9.bind([remoteString])
124+
stmt9.bind(["username": remoteString])
125+
try stmt9.run(remoteString)
126+
try stmt9.run([remoteString])
127+
try stmt9.run(["username": remoteString])
128+
try stmt9.scalar(remoteString)
129+
try stmt9.scalar([remoteString])
130+
try stmt9.scalar(["username": remoteString])
131+
132+
Statement(db, remoteString).run() // BAD
74133
}

swift/ql/test/query-tests/Security/CWE-089/sqlite3_c_api.swift

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,34 @@ struct URL
77
init?(string: String, relativeTo: URL?) {}
88
}
99

10+
struct Data {
11+
init<S>(_ elements: S) { count = 0 }
12+
13+
var count: Int
14+
15+
func copyBytes(to pointer: UnsafeMutablePointer<UInt8>, count: Int) {}
16+
}
17+
1018
extension String {
1119
init(contentsOf: URL) throws {
12-
var data = ""
20+
var data = ""
21+
22+
// ...
23+
24+
self.init(data)
25+
}
26+
27+
struct Encoding {
28+
var rawValue: UInt
1329

14-
// ...
30+
init(rawValue: UInt) {
31+
self.rawValue = rawValue
32+
}
33+
34+
static let utf16 = Encoding(rawValue: 1)
35+
}
1536

16-
self.init(data)
17-
}
37+
func data(using encoding: String.Encoding, allowLossyConversion: Bool = false) -> Data? { return nil }
1838
}
1939

2040
var SQLITE_OK : Int32 = 0
@@ -97,7 +117,7 @@ func sqlite3_finalize(
97117

98118
// --- tests ---
99119

100-
func test_sqlite3_c_api(db: OpaquePointer?) {
120+
func test_sqlite3_c_api(db: OpaquePointer?, buffer: UnsafeMutablePointer<UInt8>) {
101121
let localString = "user"
102122
let remoteString = try! String(contentsOf: URL(string: "http://example.com/")!)
103123
let remoteNumber = Int(remoteString) ?? 0
@@ -107,7 +127,6 @@ func test_sqlite3_c_api(db: OpaquePointer?) {
107127
let unsafeQuery3 = "SELECT * FROM users WHERE username='\(remoteString)'"
108128
let safeQuery1 = "SELECT * FROM users WHERE username='\(localString)'"
109129
let safeQuery2 = "SELECT * FROM users WHERE username='\(remoteNumber)'"
110-
let varQuery = "SELECT * FROM users WHERE username=?"
111130

112131
// --- exec ---
113132

@@ -119,6 +138,8 @@ func test_sqlite3_c_api(db: OpaquePointer?) {
119138

120139
// --- prepared statements ---
121140

141+
let varQuery = "SELECT * FROM users WHERE username=?"
142+
122143
var stmt1: OpaquePointer?
123144

124145
if (sqlite3_prepare(db, unsafeQuery3, -1, &stmt1, nil) == SQLITE_OK) { // BAD
@@ -147,5 +168,48 @@ func test_sqlite3_c_api(db: OpaquePointer?) {
147168
}
148169
sqlite3_finalize(stmt3)
149170

150-
// TODO: use all versions v3, 16 etc.
171+
// --- variant 'prepare' functions ---
172+
173+
var stmt4: OpaquePointer?
174+
175+
if (sqlite3_prepare_v2(db, unsafeQuery3, -1, &stmt4, nil) == SQLITE_OK) { // BAD
176+
let result = sqlite3_step(stmt4)
177+
// ...
178+
}
179+
sqlite3_finalize(stmt4)
180+
181+
var stmt5: OpaquePointer?
182+
183+
if (sqlite3_prepare_v3(db, unsafeQuery3, -1, 0, &stmt5, nil) == SQLITE_OK) { // BAD
184+
let result = sqlite3_step(stmt5)
185+
// ...
186+
}
187+
sqlite3_finalize(stmt5)
188+
189+
let data = unsafeQuery3.data(using:String.Encoding.utf16)!
190+
data.copyBytes(to: buffer, count: data.count)
191+
192+
var stmt6: OpaquePointer?
193+
194+
if (sqlite3_prepare16(db, buffer, Int32(data.count), &stmt6, nil) == SQLITE_OK) { // BAD
195+
let result = sqlite3_step(stmt6)
196+
// ...
197+
}
198+
sqlite3_finalize(stmt6)
199+
200+
var stmt7: OpaquePointer?
201+
202+
if (sqlite3_prepare16_v2(db, buffer, Int32(data.count), &stmt7, nil) == SQLITE_OK) { // BAD
203+
let result = sqlite3_step(stmt7)
204+
// ...
205+
}
206+
sqlite3_finalize(stmt7)
207+
208+
var stmt8: OpaquePointer?
209+
210+
if (sqlite3_prepare16_v3(db, buffer, Int32(data.count), 0, &stmt8, nil) == SQLITE_OK) { // BAD
211+
let result = sqlite3_step(stmt8)
212+
// ...
213+
}
214+
sqlite3_finalize(stmt8)
151215
}

0 commit comments

Comments
 (0)