Skip to content

Commit 240a2c2

Browse files
committed
feat: Complete implementation and testing of segment tree
1 parent 920230e commit 240a2c2

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Segment Tree

src/main/java/dataStructures/segmentTree/SegmentTree.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package dataStructures.segmentTree;
22

3+
/**
4+
* Implementation of a Segment Tree. Uses SegmentTreeNode as a helper node class.
5+
*/
36
public class SegmentTree {
47
private SegmentTreeNode root;
58
private int[] array;
69

10+
/**
11+
* Helper node class. Used internally.
12+
*/
713
private class SegmentTreeNode {
814
private SegmentTreeNode leftChild; // left child
915
private SegmentTreeNode rightChild; // right child
@@ -21,13 +27,17 @@ private class SegmentTreeNode {
2127
*/
2228
public SegmentTreeNode(SegmentTreeNode leftChild, SegmentTreeNode rightChild, int start, int end, int sum) {
2329
this.leftChild = leftChild;
24-
this.rightChild = rightChild
30+
this.rightChild = rightChild;
2531
this.start = start;
2632
this.end = end;
2733
this.sum = sum;
2834
}
2935
}
3036

37+
/**
38+
* Constructor.
39+
* @param nums
40+
*/
3141
public SegmentTree(int[] nums) {
3242
root = buildTree(nums, 0, nums.length - 1);
3343
array = nums;
@@ -37,7 +47,7 @@ private SegmentTreeNode buildTree(int[] nums, int start, int end) {
3747
if (start == end) {
3848
return new SegmentTreeNode(null, null, start, end, nums[start]);
3949
}
40-
int mid = start + (end-start) / 2;
50+
int mid = start + (end - start) / 2;
4151
SegmentTreeNode left = buildTree(nums, start, mid);
4252
SegmentTreeNode right = buildTree(nums, mid + 1, end);
4353
return new SegmentTreeNode(left, right, start, end, left.sum + right.sum);
@@ -68,10 +78,10 @@ private int query(SegmentTreeNode node, int leftEnd, int rightEnd) {
6878
// poss 2: ^ ^
6979
// poss 3: ^ ^
7080
if (leftEnd <= mid) {
71-
rangeSum += query(node.leftChild, leftEnd, Math.min(rightEnd, mid)); // poss1 / poss2
81+
rangeSum += query(node.leftChild, leftEnd, Math.min(rightEnd, mid)); // poss1 or poss2
7282
}
7383
if (mid + 1 <= rightEnd) {
74-
rangeSum += query(node.rightChild, Math.max(leftEnd, mid + 1), rightEnd); // poss2 / poss2
84+
rangeSum += query(node.rightChild, Math.max(leftEnd, mid + 1), rightEnd); // poss2 or poss3
7585
}
7686
return rangeSum;
7787
}
@@ -91,13 +101,14 @@ public void update(int idx, int val) {
91101
private void update(SegmentTreeNode node, int idx, int val) {
92102
if (node.start == node.end && node.start == idx) {
93103
node.sum = val; // previously, node held a single value; now updated
104+
return;
94105
}
95106
int mid = node.start + (node.end - node.start) / 2;
96107
if (idx <= mid) {
97108
update(node.leftChild, idx, val);
98109
} else {
99110
update(node.rightChild, idx, val);
100111
}
101-
node.sum = node.leftChild.sum + node.rightChild.sum; // propagate updates up
112+
node.sum = node.leftChild.sum + node.rightChild.sum; // propagate updates up
102113
}
103114
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package dataStructures.segmentTree;
2+
import static org.junit.Assert.assertEquals;
3+
4+
import org.junit.Test;
5+
6+
public class SegmentTreeTest {
7+
@Test
8+
public void construct_shouldConstructSegmentTree() {
9+
int[] arr1 = new int[] {7, 77, 37, 67, 33, 73, 13, 2, 7, 17, 87, 53};
10+
SegmentTree tree1 = new SegmentTree(arr1);
11+
assertEquals(arr1[1] + arr1[2] + arr1[3], tree1.query(1, 3));
12+
assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree1.query(4, 7));
13+
int sum1 = 0;
14+
for (int i = 0; i < arr1.length; i++) {
15+
sum1 += arr1[i];
16+
}
17+
assertEquals(sum1, tree1.query(0, arr1.length - 1));
18+
19+
20+
int[] arr2 = new int[] {7, -77, 37, 67, -33, 0, 73, -13, 2, -7, 17, 0, -87, 53, 0}; // some negatives and 0s
21+
SegmentTree tree2 = new SegmentTree(arr1);
22+
assertEquals(arr1[1] + arr1[2] + arr1[3], tree2.query(1, 3));
23+
assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree2.query(4, 7));
24+
int sum2 = 0;
25+
for (int i = 0; i < arr1.length; i++) {
26+
sum2 += arr1[i];
27+
}
28+
assertEquals(sum2, tree2.query(0, arr1.length - 1));
29+
}
30+
31+
@Test
32+
public void update_shouldUpdateSegmentTree() {
33+
int[] arr = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
34+
SegmentTree tree = new SegmentTree(arr);
35+
assertEquals(55, tree.query(0, 10));
36+
tree.update(5, 55);
37+
assertEquals(105, tree.query(0, 10));
38+
}
39+
}

0 commit comments

Comments
 (0)