Skip to content

Commit 8f96a00

Browse files
authored
Merge pull request #547 from MadKat13/MadKat13-segment-tree-java
Added segment tree in java
2 parents 5c95791 + 593aab8 commit 8f96a00

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
public class SegmentTree {
2+
3+
private long tree[];
4+
private int length = 0;
5+
6+
//build the tree with new array size: 2*2^(log2(n))-1
7+
public SegmentTree(long[] arr){
8+
length = arr.length-1;
9+
int x = (int) (Math.ceil(Math.log(arr.length) / Math.log(2)));
10+
int max_size = 2 * (int) Math.pow(2, x) - 1;
11+
tree = new long[max_size];
12+
buildTree(arr, 0, length, 0);
13+
}
14+
15+
public long getSum(int start, int end){
16+
//input validation
17+
if(start<0 || end > length || end<start){
18+
System.out.println("Invalid Input");
19+
return -1;
20+
}
21+
return getSum(0,length,start,end,0);
22+
}
23+
24+
//recursive walking and summing in the tree
25+
private long getSum(int segstart,int segend,int rangestart,int rangeend,int position){
26+
//complete segment is in the range
27+
if(rangestart<=segstart && rangeend>=segend){
28+
return tree[position];
29+
}
30+
31+
//segment is not in the range
32+
if(rangestart>segend || rangeend<segstart){
33+
return 0;
34+
}
35+
36+
//segment is partially in the range
37+
int mid = segstart + (segend-segstart)/2;
38+
39+
return getSum(segstart,mid,rangestart,rangeend,position*2+1) +
40+
getSum(mid + 1,segend,rangestart,rangeend,position*2+2);
41+
}
42+
43+
//recursive build of the tree in an array with the corresponding sum-values
44+
private long buildTree(long[] arr, int segstart, int segend, int nodeposition){
45+
//smallest segment (leaf)
46+
if(segstart==segend){
47+
tree[nodeposition] = arr[segstart];
48+
return arr[segstart];
49+
}
50+
51+
int mid = segstart + (segend-segstart)/2;
52+
tree[nodeposition] = buildTree(arr,segstart,mid,nodeposition*2+1) +
53+
buildTree(arr,mid+1,segend, nodeposition*2+2);
54+
return tree[nodeposition];
55+
}
56+
57+
58+
public void updateValue(long[] arr, int position, long value){
59+
//input validation
60+
if (position < 0 || position > length - 1) {
61+
System.out.println("Invalid Input");
62+
return;
63+
}
64+
long diff = value - arr[position];
65+
arr[position] = value;
66+
67+
updateValue(0,length,position,diff,0);
68+
}
69+
70+
//recursive update all connected node to the value
71+
private void updateValue(int segstart, int segend, int position, long diff, int node){
72+
//actual position not in segment
73+
if (position < segstart || position > segend)
74+
return;
75+
76+
tree[node] = tree[node] + diff;
77+
if(segend != segstart){
78+
int mid = segstart + (segend-segstart) / 2;
79+
updateValue(segstart,mid,position,diff,2*node+1);
80+
updateValue(mid+1,segend,position,diff,2*node+2);
81+
}
82+
}
83+
84+
public static void main(String args[])
85+
{
86+
long arr[] = {1, 3, 4, 7, 11, 18};
87+
SegmentTree tree = new SegmentTree(arr);
88+
89+
System.out.println("Sum of values in given range = " +
90+
tree.getSum(1, 3));
91+
92+
tree.updateValue(arr, 1, 10);
93+
94+
System.out.println("Updated sum of values in given range = " +
95+
tree.getSum(1, 3));
96+
}
97+
98+
}

0 commit comments

Comments
 (0)